From 7eb719df13cf8059485f52648a6a115700158301 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 19 Nov 2024 11:21:42 +0800 Subject: [PATCH] [Bugfix]Fix Phi-3 BNB online quantization (#10417) Signed-off-by: Jee Jee Li --- vllm/model_executor/layers/linear.py | 12 +++++++++--- vllm/model_executor/models/phi3.py | 10 ++++++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index e1f8a6e36d781..9da38d4857d6d 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -470,7 +470,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) if loaded_shard_id is None: - # Loaded weight is already fused on disk (qkv/mlp). + # Loaded weight is already fused on disk (mlp). + # (e.g., Phi-3's gate_up_proj). if output_dim is None: if needs_scalar_to_array: param_data, loaded_weight = adjust_scalar_to_fused_array( @@ -480,6 +481,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): param_data.copy_(loaded_weight) return current_shard_offset = 0 + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", + False) shard_offsets: List[Tuple[int, int, int]] = [] for i, output_size in enumerate(self.output_sizes): shard_offsets.append((i, current_shard_offset, output_size)) @@ -495,7 +498,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear): # Special case for Marlin. 
shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) - + if use_bitsandbytes_4bit: + shard_size = loaded_weight.shape[output_dim] // 2 + shard_offset = shard_size * shard_id loaded_weight_shard = loaded_weight.narrow( output_dim, shard_offset, shard_size) self.weight_loader(param, loaded_weight_shard, shard_id) @@ -808,7 +813,8 @@ class QKVParallelLinear(ColumnParallelLinear): needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) if loaded_shard_id is None: - # Loaded weight is already fused on disk (qkv/mlp). + # Loaded weight is already fused on disk (qkv). + # (e.g., Phi-3's qkv_proj). if output_dim is None: if needs_scalar_to_array: param_data, loaded_weight = adjust_scalar_to_fused_array( diff --git a/vllm/model_executor/models/phi3.py b/vllm/model_executor/models/phi3.py index 34141511ea791..54158bc141235 100644 --- a/vllm/model_executor/models/phi3.py +++ b/vllm/model_executor/models/phi3.py @@ -14,3 +14,13 @@ class Phi3ForCausalLM(LlamaForCausalLM): "gate_up_proj", ], } + + # BitsAndBytes specific attributes + default_bitsandbytes_target_modules = [ + ".gate_up_proj.", + ".down_proj.", + ".qkv_proj.", + ".o_proj.", + ] + # Initialize an empty dict when there is no stacked parameter mapping. + bitsandbytes_stacked_params_mapping = {}