[Bugfix] Fix Phi-3 BNB quantization with tensor parallel (#9948)

Signed-off-by: Isotr0py <2037008807@qq.com>
Author: Isotr0py (committed by GitHub)
Date: 2024-11-22 15:01:56 +08:00
Parent: a111d0151f
Commit: b6374e09b0
2 changed files with 56 additions and 6 deletions

File: vllm/model_executor/layers/linear.py

@@ -1,3 +1,4 @@
+import itertools
 from abc import abstractmethod
 from typing import Dict, List, Optional, Tuple
@@ -41,12 +42,12 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
 def adjust_bitsandbytes_4bit_shard(param: Parameter,
-                                   qkv_offsets: Dict[str, Tuple[int, int]],
+                                   shard_offsets: Dict[str, Tuple[int, int]],
                                    loaded_shard_id: str) -> Tuple[int, int]:
     """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
-    total, _ = qkv_offsets["total"]
-    orig_offset, orig_size = qkv_offsets[loaded_shard_id]
+    total, _ = shard_offsets["total"]
+    orig_offset, orig_size = shard_offsets[loaded_shard_id]
     quantized_total = param.data.shape[0]
     quantized_offset = orig_offset * quantized_total // total
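The helper now takes generic shard_offsets instead of QKV-specific offsets, so it can also map the fused gate_up_proj of Phi-3 onto the packed 4-bit parameter, not just qkv_proj. A minimal standalone sketch of the same arithmetic (not the vLLM function itself: it takes the packed row count as an argument rather than reading param.data.shape[0], assumes the size is scaled the same way as the offset, and uses invented sizes):

from typing import Dict, Tuple

def adjust_4bit_shard(quantized_total: int,
                      shard_offsets: Dict[str, Tuple[int, int]],
                      loaded_shard_id: str) -> Tuple[int, int]:
    # Total unquantized rows across all fused shards.
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]
    # Scale the unquantized offset/size down to the packed 4-bit row count.
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total
    return quantized_size, quantized_offset

# Two fused shards of 4096 rows each, packed into 2048 quantized rows:
offsets = {"0": (0, 4096), "1": (4096, 4096), "total": (8192, 0)}
print(adjust_4bit_shard(2048, offsets, "1"))  # (1024, 1024)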
@@ -499,9 +500,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 # Special case for Marlin.
                 shard_size, shard_offset = adjust_marlin_shard(
                     param, shard_size, shard_offset)
             if use_bitsandbytes_4bit:
-                shard_size = loaded_weight.shape[output_dim] // 2
-                shard_offset = shard_size * shard_id
+                index = list(itertools.accumulate([0] + self.output_sizes))
+                orig_offsets = {
+                    str(i): (index[i], size)
+                    for i, size in enumerate(self.output_sizes)
+                }
+                orig_offsets["total"] = (self.output_size, 0)
+                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
+                    param, orig_offsets, str(shard_id))
             loaded_weight_shard = loaded_weight.narrow(
                 output_dim, shard_offset, shard_size)
             self.weight_loader(param, loaded_weight_shard, shard_id)
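For a quick sense of the orig_offsets table that the new MergedColumnParallelLinear path builds with itertools.accumulate, here is a standalone snippet; the two 11008-row entries are made up and stand in for a fused gate/up projection:

import itertools

output_sizes = [11008, 11008]  # hypothetical fused gate_proj/up_proj sizes
output_size = sum(output_sizes)

index = list(itertools.accumulate([0] + output_sizes))
orig_offsets = {
    str(i): (index[i], size)
    for i, size in enumerate(output_sizes)
}
orig_offsets["total"] = (output_size, 0)

print(orig_offsets)
# {'0': (0, 11008), '1': (11008, 11008), 'total': (22016, 0)}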

File: vllm/model_executor/model_loader/loader.py

@@ -5,6 +5,7 @@ import dataclasses
 import fnmatch
 import glob
 import inspect
+import itertools
 import json
 import math
 import os
@@ -27,7 +28,9 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
-from vllm.model_executor.layers.linear import (ReplicatedLinear,
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizeMethodBase)
@@ -936,6 +939,34 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 end_index = total_size // tp_size * (tp_rank + 1)
                 weight_sub_tensor = weight_tensor[...,
                                                   start_index:end_index]
+            # Weights have been fused on disk. In this case, we assume that
+            # the weight and the module use the same name.
+            elif any(
+                    weight_name.startswith(module)
+                    for module in self.maybe_fused_weights_modules):
+                # Special case for fused weights:
+                # get the size of each shard weight tensor.
+                total_shard_sizes = next(
+                    (sizes for module, sizes in
+                     self.maybe_fused_weights_modules.items()
+                     if weight_name.startswith(module)))
+                total_size = weight_tensor.size(0)
+                assert total_size == sum(total_shard_sizes)
+                # Get the start/end index of this rank's slice of each
+                # shard weight tensor.
+                total_start_index = list(
+                    itertools.accumulate([0] + total_shard_sizes))[:-1]
+                shard_weights_index = [
+                    (idx + size // tp_size * tp_rank,
+                     idx + size // tp_size * (tp_rank + 1))
+                    for idx, size in zip(total_start_index,
+                                         total_shard_sizes)
+                ]
+                # Slice this rank's portion of each fused shard and
+                # concatenate them back into one tensor.
+                weight_tensor = [
+                    weight_tensor[start_index:end_index, ...]
+                    for start_index, end_index in shard_weights_index
+                ]
+                weight_sub_tensor = torch.cat(weight_tensor, dim=0)
+            # Shard by row.
             else:
                 total_size = weight_tensor.size(0)
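The branch added above takes this rank's row range from each fused sub-weight and concatenates the slices, so every tensor-parallel rank ends up with a contiguous [shard0_rank | shard1_rank | ...] tensor instead of a naive row split that would mix shards. A standalone sketch of that slicing, with made-up shapes and TP settings:

import itertools
from typing import List

import torch

def shard_fused_rows(weight: torch.Tensor, shard_sizes: List[int],
                     tp_rank: int, tp_size: int) -> torch.Tensor:
    """Take this rank's row range from each fused shard and re-concatenate."""
    assert weight.size(0) == sum(shard_sizes)
    starts = list(itertools.accumulate([0] + shard_sizes))[:-1]
    pieces = [
        weight[start + size // tp_size * tp_rank:
               start + size // tp_size * (tp_rank + 1), ...]
        for start, size in zip(starts, shard_sizes)
    ]
    return torch.cat(pieces, dim=0)

# A toy fused weight with 4 "gate" rows followed by 4 "up" rows, TP=2, rank 1:
fused = torch.arange(8 * 2, dtype=torch.float32).reshape(8, 2)
print(shard_fused_rows(fused, [4, 4], tp_rank=1, tp_size=2))
# Rows 2-3 (gate, rank 1) followed by rows 6-7 (up, rank 1).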
@@ -985,12 +1016,22 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         else:
             self.target_modules = self.default_target_modules
+        # Modules whose weights might be fused on disk. We need their
+        # output_sizes to shard the weights in flight correctly with TP.
+        self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
         for name, module in model.named_modules():
             # Some modules like `ReplicatedLinear` should not have their weights
             # sharded. The reason for implementing it this way is to avoid new
             # static variable in the model implementation.
             if isinstance(module, (ReplicatedLinear, )):
                 self.unsharded_weights_modules.append(name)
+            # `QKVParallelLinear` and `MergedColumnParallelLinear` might have
+            # fused weights on disk. We need to use the output sizes of these
+            # modules to shard the weights correctly.
+            elif isinstance(module,
+                            (QKVParallelLinear, MergedColumnParallelLinear)):
+                self.maybe_fused_weights_modules[name] = module.output_sizes
             # In TP, these weights are partitioned along the column
             # dimension (dim=-1)
             elif isinstance(module, (RowParallelLinear, )):
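To show what the registration loop above collects, here is a toy sketch that uses a plain nn.Linear subclass with an output_sizes attribute as a stand-in for vLLM's QKVParallelLinear / MergedColumnParallelLinear; the module names and sizes are invented:

import torch.nn as nn

class FusedLinearStub(nn.Linear):
    """Stand-in for a fused (QKV or gate/up) linear layer."""

    def __init__(self, in_features, output_sizes):
        super().__init__(in_features, sum(output_sizes))
        self.output_sizes = output_sizes

model = nn.ModuleDict({
    "qkv_proj": FusedLinearStub(1024, [1024, 256, 256]),
    "gate_up_proj": FusedLinearStub(1024, [2816, 2816]),
    "down_proj": nn.Linear(2816, 1024),
})

maybe_fused_weights_modules = {
    name: module.output_sizes
    for name, module in model.named_modules()
    if isinstance(module, FusedLinearStub)
}
print(maybe_fused_weights_modules)
# {'qkv_proj': [1024, 256, 256], 'gate_up_proj': [2816, 2816]}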