Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-17 07:45:01 +08:00)
[Quantization][1/N] MoE support BNB-Inflight Quantization (#20061)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
commit 8020e98c9f (parent 762be26a8e)
@@ -14,7 +14,7 @@ from transformers import BitsAndBytesConfig
 from tests.quantization.utils import is_quant_method_supported
 
 from ...utils import compare_two_settings, multi_gpu_test
-from ..utils import check_embeddings_close
+from ..utils import check_embeddings_close, check_logprobs_close
 
 models_4bit_to_test = [
     ("facebook/opt-125m", "quantize opt model inflight"),
@@ -26,6 +26,10 @@ models_4bit_to_embedding_test = [
     ("intfloat/e5-mistral-7b-instruct", "quantize embedding model inflight"),
 ]
 
+models_4bit_to_moe_test = [
+    ("allenai/OLMoE-1B-7B-0125-Instruct", "quantize moe model inflight"),
+]
+
 models_pre_qaunt_4bit_to_test = [
     ('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed',
      'read pre-quantized 4-bit FP4 model'),
@@ -115,6 +119,35 @@ def test_load_pp_4bit_bnb_model(model_name, description) -> None:
     compare_two_settings(model_name, common_args, pp_args)
 
 
+@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
+                    reason='bitsandbytes is not supported on this GPU type.')
+@pytest.mark.parametrize("model_name, description", models_4bit_to_moe_test)
+def test_4bit_bnb_moe_model(hf_runner, vllm_runner, example_prompts,
+                            model_name, description) -> None:
+
+    hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_use_double_quant=True,
+    ))
+    with vllm_runner(model_name,
+                     quantization='bitsandbytes',
+                     enforce_eager=False) as llm:
+        vllm_outputs = llm.generate_greedy_logprobs(example_prompts,
+                                                    max_tokens=32,
+                                                    num_logprobs=5)
+
+    with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
+        transformers_outputs = llm.generate_greedy_logprobs_limit(
+            example_prompts, max_tokens=32, num_logprobs=5)
+    check_logprobs_close(
+        outputs_0_lst=transformers_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="transformers",
+        name_1="vllm",
+    )
+
+
 @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                     reason='bitsandbytes is not supported on this GPU type.')
 @pytest.mark.parametrize("model_name, description",
@@ -182,7 +215,8 @@ def validate_generated_texts(hf_runner,
                              model_name,
                              pre_quant=False,
                              hf_model_kwargs=None,
-                             vllm_tp_size=1):
+                             vllm_tp_size=1,
+                             max_tokens=8):
 
     # NOTE: run vLLM first, as it requires a clean process
     # when using distributed inference
@@ -190,7 +224,8 @@ def validate_generated_texts(hf_runner,
                      quantization=None if pre_quant else 'bitsandbytes',
                      tensor_parallel_size=vllm_tp_size,
                      enforce_eager=False) as llm:
-        vllm_outputs = llm.generate_greedy(prompts, 8)
+
+        vllm_outputs = llm.generate_greedy(prompts, max_tokens)
         vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")
 
         # Clean up the GPU memory for the next test
@@ -202,19 +237,17 @@ def validate_generated_texts(hf_runner,
 
     # Run with HF runner
     with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
-        hf_outputs = llm.generate_greedy(prompts, 8)
+        hf_outputs = llm.generate_greedy(prompts, max_tokens)
        hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")
 
     # Clean up the GPU memory for the next test
     gc.collect()
     torch.cuda.empty_cache()
 
     # Compare the generated strings
     for hf_log, vllm_log in zip(hf_logs, vllm_logs):
         hf_str = hf_log["generated_text"]
         vllm_str = vllm_log["generated_text"]
         prompt = hf_log["prompt"]
 
         assert hf_str == vllm_str, (f"Model: {model_name}"
                                     f"Mismatch between HF and vLLM outputs:\n"
                                     f"Prompt: {prompt}\n"

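
Note: the new test quantizes an MoE checkpoint on the fly with BitsAndBytes NF4 and compares greedy logprobs against the Transformers reference. Outside the test harness, the same path can be exercised through vLLM's LLM API; the snippet below is only a minimal sketch, and the model name and sampling settings are illustrative, not prescribed by this commit.

    # Minimal sketch: in-flight BitsAndBytes 4-bit quantization of an MoE model.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="allenai/OLMoE-1B-7B-0125-Instruct",
        quantization="bitsandbytes",  # quantize weights while loading
        enforce_eager=False,
    )
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(temperature=0.0, max_tokens=32))
    print(outputs[0].outputs[0].text)
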
@@ -883,14 +883,21 @@ class FusedMoE(torch.nn.Module):
                           expert_data=expert_data,
                           tp_rank=tp_rank)
 
-    def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
-                  shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
+    def _load_w13(self,
+                  expert_data: torch.Tensor,
+                  shard_dim: int,
+                  shard_id: str,
+                  loaded_weight: torch.Tensor,
+                  tp_rank: int,
+                  load_full: bool = False):
 
         # Index the loaded weight for tp sharding.
         # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
         shard_size = expert_data.shape[shard_dim] // 2
-        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
-                                             shard_size)
+        if not load_full:
+            loaded_weight = loaded_weight.narrow(shard_dim,
+                                                 shard_size * tp_rank,
+                                                 shard_size)
         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
         if shard_id == "w1":
@@ -998,6 +1005,27 @@ class FusedMoE(torch.nn.Module):
             param.data.copy_(loaded_weight)
             return True if return_success else None
 
+        # Case for BitsAndBytes
+        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+        if use_bitsandbytes_4bit:
+            shard_dim = 0
+
+            expert_data = param.data[expert_id]
+            if shard_id == "w2":
+                expert_data.copy_(loaded_weight)
+            elif shard_id in ("w1", "w3"):
+                # BNB inflight quantization has already sharded the weights
+                full_load = True
+                self._load_w13(
+                    shard_id=shard_id,
+                    shard_dim=shard_dim,
+                    loaded_weight=loaded_weight,
+                    expert_data=expert_data,
+                    tp_rank=self.tp_rank,
+                    load_full=full_load,
+                )
+            return True if return_success else None
+
         # is_transposed: if the dim to shard the weight
         # should be flipped. Required by GPTQ, compressed-tensors
         # should be whatever dimension intermediate_size_per_partition is

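
Note: `_load_w13` normally narrows the checkpoint tensor to this rank's tensor-parallel slice before copying it into the gate or up half of the fused w13 parameter; `load_full=True` skips the narrow because the BNB in-flight path hands the loader an already-sharded tensor. A small standalone sketch of the two paths follows; all shapes are made up for illustration.

    import torch

    # Toy shapes: the fused w13 expert slice is (2 * intermediate // tp, hidden).
    intermediate_size, hidden, tp_size, tp_rank = 8, 4, 2, 1
    expert_data = torch.zeros(2 * intermediate_size // tp_size, hidden)
    shard_size = expert_data.shape[0] // 2  # half for w1, half for w3

    # Normal path: checkpoint holds the full w1, narrow out this rank's slice.
    full_w1 = torch.randn(intermediate_size, hidden)
    w1_shard = full_w1.narrow(0, shard_size * tp_rank, shard_size)
    expert_data.narrow(0, 0, shard_size).copy_(w1_shard)

    # load_full path (BNB in-flight): the tensor is already this rank's shard,
    # so it is copied without narrowing.
    pre_sharded_w1 = torch.randn(shard_size, hidden)
    expert_data.narrow(0, 0, shard_size).copy_(pre_sharded_w1)
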
@@ -1,10 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Any, Optional
+from typing import Any, Callable, Optional, Union
 
 import torch
 
+from vllm.model_executor.layers.fused_moe import fused_experts
+from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
+                                                        FusedMoEMethodBase)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod,
                                                set_weight_attrs)
@@ -120,12 +123,15 @@ class BitsAndBytesConfig(QuantizationConfig):
             llm_int8_skip_modules=llm_int8_skip_modules,
             llm_int8_threshold=llm_int8_threshold)
 
-    def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["LinearMethodBase"]:
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[Union["LinearMethodBase", "BitsAndBytesMoEMethod"]]:
         if isinstance(layer, LinearBase):
             if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules):
                 return UnquantizedLinearMethod()
             return BitsAndBytesLinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return BitsAndBytesMoEMethod(self)
         return None
 
 
@@ -146,6 +152,13 @@ def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]):
     return substr_check or prefix_check
 
 
+def calculate_quant_ratio(dtype):
+    if dtype.is_floating_point:
+        return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits
+    else:
+        return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits
+
+
 class BitsAndBytesLinearMethod(LinearMethodBase):
     """Linear method for BitsAndBytes.
 
@@ -173,12 +186,6 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
                        **extra_weight_attrs):
         from bitsandbytes.nn import Int8Params
 
-        def calculate_quant_ratio(dtype):
-            if dtype.is_floating_point:
-                return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits
-            else:
-                return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits
-
         def create_qweight_for_8bit():
             qweight = Int8Params(
                 data=torch.empty(sum(output_partition_sizes),

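
Note: `calculate_quant_ratio` is hoisted out of `BitsAndBytesLinearMethod.create_weights` so the new MoE method can reuse it. It returns how many elements of `params_dtype` map onto one byte of the packed uint8 qweight buffer, and the result is stored on the parameter as `pack_factor`. A quick, purely illustrative check of the arithmetic:

    import torch

    def calculate_quant_ratio(dtype):
        # Bits of the weight dtype divided by the 8 bits of a uint8 byte.
        if dtype.is_floating_point:
            return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits
        else:
            return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits

    assert calculate_quant_ratio(torch.bfloat16) == 2  # 16 // 8
    assert calculate_quant_ratio(torch.float32) == 4   # 32 // 8
    # With ratio 2, the fused w13 buffer below is sized as
    # (hidden_size * 2 * intermediate_size_per_partition) // 2 uint8 elements.
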
@@ -394,3 +401,210 @@ try:
 
 except AttributeError as error:
     raise error
+
+
+class BitsAndBytesMoEMethod(FusedMoEMethodBase):
+    """MoE method for BitsAndBytes.
+
+    Args:
+        quant_config: The BitsAndBytes quantization config.
+    """
+
+    def __init__(self, quant_config: BitsAndBytesConfig):
+        try:
+            import bitsandbytes
+            if bitsandbytes.__version__ < "0.45.3":
+                raise ImportError("bitsandbytes version is wrong. Please "
+                                  "install bitsandbytes>=0.45.3.")
+        except ImportError as err:
+            raise ImportError("Please install bitsandbytes>=0.45.3 via "
+                              "`pip install bitsandbytes>=0.45.3` to use "
+                              "bitsandbytes quantizer.") from err
+        self.topk_indices_dtype = None
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        if self.quant_config.load_in_8bit:
+            call_fun = self._create_weights_8bit
+        else:
+            call_fun = self._create_weights_4bit
+        call_fun(
+            layer,
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            params_dtype,
+            **extra_weight_attrs,
+        )
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        if enable_eplb:
+            raise NotImplementedError(
+                "EPLB not supported for `BitsAndBytesMoEMethod` yet.")
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+            indices_type=self.topk_indices_dtype)
+        if self.quant_config.load_in_8bit:
+            w13, w2 = self._apply_8bit_dequant(layer)
+        else:
+            w13, w2 = self._apply_4bit_dequnt(layer)
+        return fused_experts(
+            hidden_states=x,
+            w1=w13,
+            w2=w2,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            global_num_experts=global_num_experts,
+            expert_map=expert_map,
+        )
+
+    def _create_weights_4bit(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        quant_ratio = calculate_quant_ratio(params_dtype)
+        # Fused gate_up_proj (column parallel)
+        w13_total_size = (hidden_size * 2 *
+                          intermediate_size_per_partition) // quant_ratio
+        w13_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                w13_total_size,
+                1,
+                dtype=torch.uint8,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_qweight)
+        set_weight_attrs(w13_qweight, extra_weight_attrs)
+        set_weight_attrs(
+            w13_qweight,
+            {
+                "num_experts": num_experts,
+                "input_dim": hidden_size,
+                "output_dim": 2 * intermediate_size_per_partition,
+                "experts_shape": (
+                    num_experts,
+                    intermediate_size_per_partition * 2,
+                    hidden_size,
+                ),
+                "pack_factor": quant_ratio,
+                "use_bitsandbytes_4bit": True,
+            },
+        )
+        # down_proj (row parallel)
+        w2_total_size = (hidden_size *
+                         intermediate_size_per_partition) // quant_ratio
+        w2_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                w2_total_size,
+                1,
+                dtype=torch.uint8,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(
+            w2_qweight,
+            {
+                "num_experts": num_experts,
+                "input_dim": intermediate_size_per_partition,
+                "output_dim": hidden_size,
+                "experts_shape": (
+                    num_experts,
+                    hidden_size,
+                    intermediate_size_per_partition,
+                ),
+                "pack_factor": quant_ratio,
+                "use_bitsandbytes_4bit": True,
+            },
+        )
+        layer.register_parameter("w2_weight", w2_qweight)
+        set_weight_attrs(w2_qweight, extra_weight_attrs)
+
+    def _create_weights_8bit(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        raise NotImplementedError
+
+    def _apply_4bit_dequnt(
+            self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]:
+        from bitsandbytes.functional import dequantize_4bit
+        w13 = dequantize_4bit(
+            layer.w13_weight.reshape(-1, 1),
+            layer.w13_weight.bnb_quant_state,
+        )
+        w2 = dequantize_4bit(
+            layer.w2_weight.reshape(-1, 1),
+            layer.w2_weight.bnb_quant_state,
+        )
+        w13 = w13.reshape(layer.w13_weight.experts_shape)
+        w2 = w2.reshape(layer.w2_weight.experts_shape)
+        return w13, w2
+
+    def _apply_8bit_dequant(
+            self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError

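
Note: on the forward path, `apply` dequantizes the packed NF4 expert weights back to the activation dtype and hands dense w13/w2 tensors to `fused_experts`, so the fused MoE kernel itself is unchanged. Roughly what the dequantization step amounts to is sketched below with toy shapes; it assumes a CUDA build of bitsandbytes (>= 0.45.3 per the version check above), and in the real code the quant states are built by the BitsAndBytes model loader rather than by `quantize_4bit`.

    import torch
    from bitsandbytes.functional import quantize_4bit, dequantize_4bit

    # Toy stand-in for one fused w13 tensor: (2 * intermediate_size, hidden_size).
    w13_fp16 = torch.randn(16, 8, dtype=torch.float16, device="cuda")

    # Packing yields a uint8 buffer plus a QuantState with absmax/blocksize/shape.
    packed, quant_state = quantize_4bit(w13_fp16, quant_type="nf4")

    # _apply_4bit_dequnt does the inverse at run time and reshapes to the
    # experts_shape recorded on the parameter.
    restored = dequantize_4bit(packed, quant_state).reshape(w13_fp16.shape)
    print((restored - w13_fp16).abs().max())  # small quantization error
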
@@ -20,6 +20,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 # yapf: enable
 from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (LinearBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
@@ -411,9 +412,33 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 # in case model has a mixture of disk-merged and disk-split
                 # weights with same last name.
                 self.target_modules.append(name)
+            elif (isinstance(module, FusedMoE)
+                  and hasattr(module.quant_method, "quant_config")):
+                if not hasattr(model, "get_expert_mapping"):
+                    raise AttributeError(
+                        f"MoE Model {type(model).__name__} does not support "
+                        "BitsAndBytes quantization yet. Ensure this model has "
+                        "'get_expert_mapping' method.")
+                # TODO: support FusedMoE with prequant and 8bit.
+                if self.pre_quant:
+                    raise ValueError(
+                        "Prequant BitsAndBytes models with FusedMoE is not "
+                        "supported yet.")
+                if self.load_8bit:
+                    raise ValueError(
+                        "BitsAndBytes 8bit quantization with FusedMoE is not "
+                        "supported yet.")
+                # Get the corresponding weight name using module name and
+                # get_expert_mapping.
+                expert_mapping = model.get_expert_mapping()
+                for exp in expert_mapping:
+                    weight_name = exp[1]
+                    rep_name = name.replace("experts",
+                                            "") + weight_name.removesuffix(".")
+                    self.target_modules.append(rep_name)
 
         assert (self.target_modules
-                ), "vllm currently does not support BNB quantization for"
+                ), "vLLM currently does not support BNB quantization for"
         f" {type(model).__name__}"
 
     def _classify_module_sharding(self, model: nn.Module):
@@ -437,6 +462,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             # dimension (dim=-1)
             elif isinstance(module, (RowParallelLinear, )):
                 self.column_sharded_weights_modules.append(name)
+            elif isinstance(module, FusedMoE):
+                expert_mapping = model.get_expert_mapping()
+                for exp in expert_mapping:
+                    if exp[-1] == "w2":
+                        weight_name = exp[1]
+                        rep_name = name.replace(
+                            "experts", "") + weight_name.removesuffix(".")
+                        self.column_sharded_weights_modules.append(rep_name)
 
     def _verify_model_compatibility(self, model: nn.Module,
                                     model_config: ModelConfig) -> None:

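
Note: the loader now treats every expert projection of a `FusedMoE` module as a BNB target module. `get_expert_mapping()` (added to the MoE models further down in this commit) returns `(param_name, weight_name, expert_id, shard_id)` tuples; the loader only needs the checkpoint-side `weight_name` and the `shard_id`. A rough sketch of the name derivation it performs, with a hand-written mapping entry and a made-up module name (both are hypothetical, merely shaped like `FusedMoE.make_expert_params_mapping` output):

    # (param_name, weight_name, expert_id, shard_id), hypothetical values
    expert_mapping = [
        ("experts.w13_", "experts.0.gate_proj.", 0, "w1"),
        ("experts.w2_", "experts.0.down_proj.", 0, "w2"),
        ("experts.w13_", "experts.0.up_proj.", 0, "w3"),
    ]

    name = "model.layers.0.mlp.experts"  # FusedMoE module name (illustrative)
    for _, weight_name, _, shard_id in expert_mapping:
        rep_name = name.replace("experts", "") + weight_name.removesuffix(".")
        # e.g. "model.layers.0.mlp.experts.0.gate_proj" becomes a BNB target;
        # the "w2" (down_proj) entries are additionally marked column-sharded.
        print(shard_id, rep_name)
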
@@ -490,34 +523,132 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         self._get_bnb_target_modules(model)
         self._classify_module_sharding(model)
 
-    def load_weights(self, model: nn.Module,
-                     model_config: ModelConfig) -> None:
-        self._verify_model_compatibility(model, model_config)
-        self._initialize_loader_state(model, model_config)
-
-        logger.info("Loading weights with BitsAndBytes quantization. "
-                    "May take a while ...")
-        qweight_iterator, quant_state_dict = (
-            self._get_quantized_weights_iterator(
-                model_config.model,
-                model_config.revision,
-            ))
-        weights_to_load = {name for name, _ in model.named_parameters()}
-        loaded_weights = model.load_weights(qweight_iterator)
-        # Some models may have weights loading tracker unimplemented.
-        if loaded_weights is not None:
-            weights_not_loaded = weights_to_load - loaded_weights
-            if weights_not_loaded:
-                raise ValueError("Following weights were not initialized from "
-                                 f"checkpoint: {weights_not_loaded}")
-
-        param_dict = dict(model.named_parameters())
+    def _dequantize_dq(self, quant_states: Any):
+        """
+        When BNB employs Double Quantization, we perform the dequantization of
+        these constants during weight loading rather than at inference time,
+        thereby avoiding this computational overhead during inference. This
+        comes at the cost of increased memory usage.
+        """
+        from bitsandbytes.functional import QuantState, dequantize_blockwise
+
+        def _dequantize_single_state(quant_state):
+            """Helper function to dequantize a single QuantState object."""
+            if not (isinstance(quant_state, QuantState)
+                    and quant_state.nested):
+                return
+
+            # Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356
+            absmax = dequantize_blockwise(quant_state.absmax,
+                                          quant_state.state2)
+            absmax += quant_state.offset
+
+            # Ensure float32 dtype
+            if absmax.dtype != torch.float32:
+                absmax = absmax.float()
+
+            quant_state.absmax = absmax
+            quant_state.nested = False
+            quant_state.offset = None
+            quant_state.state2 = None
+
+        if isinstance(quant_states, dict):
+            for quant_state in quant_states.values():
+                _dequantize_single_state(quant_state)
+        else:
+            _dequantize_single_state(quant_states)
+        return quant_states
+
+    def _fuse_moe_quant_states(self, model: nn.Module,
+                               quant_states_dict: dict) -> dict:
+        """
+        This function consolidates individual expert quantization states into
+        fused representations for w13 and w2.
+        """
+        from bitsandbytes.functional import QuantState
+
+        if not hasattr(model, "get_expert_mapping"):
+            return dict()
+
+        expert_mapping = model.get_expert_mapping()
+        expert_qs_dict = {}
+        for name, module in model.named_modules():
+            if not isinstance(module, FusedMoE):
+                continue
+            w1_states_lst = []
+            w2_states_lst = []
+            w3_states_lst = []
+            for exp in expert_mapping:
+                shard_id = exp[-1]
+                if shard_id not in ("w1", "w2", "w3"):
+                    raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
+                                     f"got {shard_id}.")
+                layer_prefix = name.split("experts")[0]
+                weight_qual_name = layer_prefix + exp[1] + "weight"
+                quant_state = self._dequantize_dq(
+                    quant_states_dict[weight_qual_name])
+                if shard_id == "w1":
+                    w1_states_lst.append(quant_state)
+                elif shard_id == "w2":
+                    w2_states_lst.append(quant_state)
+                else:
+                    w3_states_lst.append(quant_state)
+                del quant_states_dict[weight_qual_name]
+            assert (len(w1_states_lst) == len(w2_states_lst) ==
+                    len(w3_states_lst))
+            w13_absmax_lst = []
+            w2_absmax_lst = []
+            w13_total_dim0 = 0
+            w2_total_dim0 = 0
+            for w1_qs, w2_qs, w3_qs in zip(w1_states_lst, w2_states_lst,
+                                           w3_states_lst):
+                assert w1_qs.shape == w3_qs.shape
+                assert w1_qs.blocksize == w2_qs.blocksize == w3_qs.blocksize
+                assert w1_qs.dtype == w2_qs.dtype == w3_qs.dtype
+                # w1 and w3 are interleaved in storage
+                w13_absmax_lst.append(w1_qs.absmax)
+                w13_absmax_lst.append(w3_qs.absmax)
+                w2_absmax_lst.append(w2_qs.absmax)
+                w13_total_dim0 += w1_qs.shape[0] + w3_qs.shape[0]
+                w2_total_dim0 += w2_qs.shape[0]
+
+            w13_absmax = torch.cat(w13_absmax_lst)
+            w2_absmax = torch.cat(w2_absmax_lst)
+            # Create fused quantization state for w13.
+            w13_qs = QuantState(
+                absmax=w13_absmax,
+                shape=(w13_total_dim0, w1_states_lst[0].shape[1]),
+                code=w1_states_lst[0].code,
+                blocksize=w1_states_lst[0].blocksize,
+                quant_type="nf4",
+                dtype=w1_states_lst[0].dtype,
+            )
+            # Create fused quantization state for w2.
+            w2_qs = QuantState(
+                absmax=w2_absmax,
+                shape=(w2_total_dim0, w2_states_lst[0].shape[1]),
+                code=w2_states_lst[0].code,
+                blocksize=w2_states_lst[0].blocksize,
+                quant_type="nf4",
+                dtype=w2_states_lst[0].dtype,
+            )
+            # The weight suffixes .w13_weight and .w2_weight are consistent
+            # with the param in BitsAndBytesMoEMethod.
+            w13_weight_name = name + ".w13_weight"
+            w2_weight_name = name + ".w2_weight"
+            expert_qs_dict[w13_weight_name] = w13_qs
+            expert_qs_dict[w2_weight_name] = w2_qs
+        return expert_qs_dict
+
+    def _stack_quantization_states(
+            self, model: nn.Module,
+            quant_state_dict: dict) -> dict[str, dict[int, Any]]:
         stacked_quant_state_dict: dict[str, dict[int, Any]] = {}
         # TODO: Change this lazy import to normal import
         # after the checks are updated to run on a new version
         from vllm.model_executor.models.utils import is_pp_missing_parameter
+        param_dict = dict(model.named_parameters())
         for quant_param_name in quant_state_dict:
             if is_pp_missing_parameter(quant_param_name, model):
                 continue

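
Note: `_fuse_moe_quant_states` rebuilds one QuantState per fused parameter: the per-expert w1/w3 absmax vectors are interleaved and concatenated into a single w13 state (likewise for w2), while the fused dim-0 shape tracks the concatenation. A toy illustration of that bookkeeping, using plain tensors in place of real quant states (all sizes are illustrative):

    import torch

    num_experts, intermediate, hidden, blocksize = 4, 64, 32, 64

    # Stand-ins for per-expert absmax vectors: one scale per `blocksize` weights.
    w1_absmax = [torch.rand(intermediate * hidden // blocksize)
                 for _ in range(num_experts)]
    w3_absmax = [torch.rand(intermediate * hidden // blocksize)
                 for _ in range(num_experts)]

    w13_absmax_lst, w13_total_dim0 = [], 0
    for a1, a3 in zip(w1_absmax, w3_absmax):
        # w1 and w3 are interleaved per expert, mirroring the fused w13 layout.
        w13_absmax_lst += [a1, a3]
        w13_total_dim0 += 2 * intermediate

    w13_absmax = torch.cat(w13_absmax_lst)
    print(w13_total_dim0, w13_absmax.numel())  # 512, 256 for these toy sizes
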
@@ -558,14 +689,20 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
                 stacked_quant_state_dict[quant_param_name][shard_index] = (
                     quant_state_dict[non_stacked_param_name])
+        return stacked_quant_state_dict
 
+    def _bind_quant_states_to_params(self, model: nn.Module,
+                                     stacked_quant_state_dict: dict) -> None:
         # save quant_states and offsets as the attributes of the parameters
+        param_dict = dict(model.named_parameters())
         for param_name, param in param_dict.items():
             if param_name in stacked_quant_state_dict:
                 quant_states = stacked_quant_state_dict[param_name]
                 # Dequantize double quantized values during weight loading.
-                dequantize_dq(quant_states)
+                self._dequantize_dq(quant_states)
                 set_weight_attrs(param, {"bnb_quant_state": quant_states})
+                if not isinstance(quant_states, dict):
+                    continue
 
                 pack_ratio = getattr(param, "pack_factor", -1)
                 if pack_ratio == -1:
@@ -585,29 +722,40 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 if self.load_8bit:
                     set_weight_attrs(
                         param, {"matmul_state": [None] * len(quant_states)})
+
+    def load_weights(self, model: nn.Module,
+                     model_config: ModelConfig) -> None:
+
+        self._verify_model_compatibility(model, model_config)
+        self._initialize_loader_state(model, model_config)
+
+        logger.info("Loading weights with BitsAndBytes quantization. "
+                    "May take a while ...")
+        qweight_iterator, quant_state_dict = (
+            self._get_quantized_weights_iterator(
+                model_config.model,
+                model_config.revision,
+            ))
+        weights_to_load = {name for name, _ in model.named_parameters()}
+        loaded_weights = model.load_weights(qweight_iterator)
+        # Some models may have weights loading tracker unimplemented.
+        if loaded_weights is not None:
+            weights_not_loaded = weights_to_load - loaded_weights
+            if weights_not_loaded:
+                raise ValueError("Following weights were not initialized from "
+                                 f"checkpoint: {weights_not_loaded}")
+        expert_quant_state_dict = self._fuse_moe_quant_states(
+            model, quant_state_dict)
+
+        stacked_quant_state_dict = self._stack_quantization_states(
+            model, quant_state_dict)
+
+        stacked_quant_state_dict = {
+            **expert_quant_state_dict,
+            **stacked_quant_state_dict
+        }
+        self._bind_quant_states_to_params(model, stacked_quant_state_dict)
         torch.cuda.empty_cache()
 
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)
-
-
-def dequantize_dq(quant_states: dict) -> None:
-    """
-    When BNB employs Double Quantization, we perform the dequantization of
-    these constants during weight loading rather than at inference time,
-    thereby avoiding this computational overhead during inference. This comes
-    at the cost of increased memory usage.
-    """
-    from bitsandbytes.functional import QuantState, dequantize_blockwise
-    for _, quant_state in quant_states.items():
-        # Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356
-        if isinstance(quant_state, QuantState) and quant_state.nested:
-            absmax = dequantize_blockwise(quant_state.absmax,
-                                          quant_state.state2)
-            absmax += quant_state.offset
-            if absmax.dtype != torch.float32:
-                absmax = absmax.float()
-            quant_state.absmax = absmax
-            quant_state.nested = False
-            quant_state.offset = None
-            quant_state.state2 = None

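
Note: after loading, each quantized parameter carries its metadata as attributes (`pack_factor`, `use_bitsandbytes_4bit`, `experts_shape`, plus the fused `bnb_quant_state` bound by `_bind_quant_states_to_params`), which is exactly what `BitsAndBytesMoEMethod._apply_4bit_dequnt` reads at run time. The sketch below only emulates that layout on a toy parameter; the attribute names come from this diff, the sizes are made up.

    import torch

    num_experts, intermediate, hidden, pack_factor = 4, 64, 32, 2
    w13 = torch.nn.Parameter(
        torch.empty(num_experts, (hidden * 2 * intermediate) // pack_factor, 1,
                    dtype=torch.uint8),
        requires_grad=False)
    # What set_weight_attrs leaves on the parameter at create_weights time:
    w13.pack_factor = pack_factor
    w13.use_bitsandbytes_4bit = True
    w13.experts_shape = (num_experts, intermediate * 2, hidden)
    # The loader later attaches the fused QuantState as `bnb_quant_state`.
    print(w13.shape, w13.experts_shape, w13.pack_factor)
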
@@ -330,6 +330,15 @@ class OlmoeModel(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts)
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -341,14 +350,6 @@ class OlmoeModel(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.num_experts)
-
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
@@ -379,7 +380,7 @@ class OlmoeModel(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for mapping in expert_params_mapping:
+                for mapping in self.get_expert_mapping():
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
@@ -425,6 +426,17 @@ class OlmoeModel(nn.Module):
 
 
 class OlmoeForCausalLM(nn.Module, SupportsPP):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -466,3 +478,6 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()

@@ -516,6 +516,14 @@ class PhiMoEModel(nn.Module):
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="w1",
+            ckpt_down_proj_name="w2",
+            ckpt_up_proj_name="w3",
+            num_experts=self.config.num_local_experts,
+        )
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -672,3 +680,6 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()

@@ -391,6 +391,15 @@ class Qwen2MoeModel(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts)
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -402,14 +411,6 @@ class Qwen2MoeModel(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.num_experts)
-
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
@@ -441,11 +442,13 @@ class Qwen2MoeModel(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for mapping in expert_params_mapping:
+                for mapping in self.get_expert_mapping():
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
+                    if "layers.13.mlp.experts.w2_weight" in name:
+                        pass
                     # Skip layers on other devices.
                     if is_pp_missing_parameter(name, self):
                         continue
@@ -493,6 +496,17 @@ class Qwen2MoeModel(nn.Module):
 class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
 
     fall_back_to_pt_during_load = False
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -538,3 +552,6 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()

@@ -375,6 +375,15 @@ class Qwen3MoeModel(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts)
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -393,12 +402,7 @@ class Qwen3MoeModel(nn.Module):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.num_experts)
+        expert_params_mapping = self.get_expert_mapping()
 
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
@@ -539,3 +543,6 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
