mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-25 06:44:25 +08:00
[Bugfix] Fix Phi-3 BNB quantization with tensor parallel (#9948)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
a111d0151f
commit
b6374e09b0
@ -1,3 +1,4 @@
|
|||||||
|
import itertools
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
@ -41,12 +42,12 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
|
|||||||
|
|
||||||
|
|
||||||
def adjust_bitsandbytes_4bit_shard(param: Parameter,
|
def adjust_bitsandbytes_4bit_shard(param: Parameter,
|
||||||
qkv_offsets: Dict[str, Tuple[int, int]],
|
shard_offsets: Dict[str, Tuple[int, int]],
|
||||||
loaded_shard_id: str) -> Tuple[int, int]:
|
loaded_shard_id: str) -> Tuple[int, int]:
|
||||||
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
|
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
|
||||||
|
|
||||||
total, _ = qkv_offsets["total"]
|
total, _ = shard_offsets["total"]
|
||||||
orig_offset, orig_size = qkv_offsets[loaded_shard_id]
|
orig_offset, orig_size = shard_offsets[loaded_shard_id]
|
||||||
|
|
||||||
quantized_total = param.data.shape[0]
|
quantized_total = param.data.shape[0]
|
||||||
quantized_offset = orig_offset * quantized_total // total
|
quantized_offset = orig_offset * quantized_total // total
|
||||||
@ -499,9 +500,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
# Special case for Marlin.
|
# Special case for Marlin.
|
||||||
shard_size, shard_offset = adjust_marlin_shard(
|
shard_size, shard_offset = adjust_marlin_shard(
|
||||||
param, shard_size, shard_offset)
|
param, shard_size, shard_offset)
|
||||||
|
|
||||||
if use_bitsandbytes_4bit:
|
if use_bitsandbytes_4bit:
|
||||||
shard_size = loaded_weight.shape[output_dim] // 2
|
index = list(itertools.accumulate([0] + self.output_sizes))
|
||||||
shard_offset = shard_size * shard_id
|
orig_offsets = {
|
||||||
|
str(i): (index[i], size)
|
||||||
|
for i, size in enumerate(self.output_sizes)
|
||||||
|
}
|
||||||
|
orig_offsets["total"] = (self.output_size, 0)
|
||||||
|
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
|
||||||
|
param, orig_offsets, str(shard_id))
|
||||||
|
|
||||||
loaded_weight_shard = loaded_weight.narrow(
|
loaded_weight_shard = loaded_weight.narrow(
|
||||||
output_dim, shard_offset, shard_size)
|
output_dim, shard_offset, shard_size)
|
||||||
self.weight_loader(param, loaded_weight_shard, shard_id)
|
self.weight_loader(param, loaded_weight_shard, shard_id)
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import dataclasses
|
|||||||
import fnmatch
|
import fnmatch
|
||||||
import glob
|
import glob
|
||||||
import inspect
|
import inspect
|
||||||
|
import itertools
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@ -27,7 +28,9 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
|
|||||||
get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_world_size)
|
||||||
from vllm.envs import VLLM_USE_MODELSCOPE
|
from vllm.envs import VLLM_USE_MODELSCOPE
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.linear import (ReplicatedLinear,
|
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||||
|
QKVParallelLinear,
|
||||||
|
ReplicatedLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizeMethodBase)
|
QuantizeMethodBase)
|
||||||
@ -936,6 +939,34 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
end_index = total_size // tp_size * (tp_rank + 1)
|
end_index = total_size // tp_size * (tp_rank + 1)
|
||||||
weight_sub_tensor = weight_tensor[...,
|
weight_sub_tensor = weight_tensor[...,
|
||||||
start_index:end_index]
|
start_index:end_index]
|
||||||
|
# Weights have fused on disk. In this case, we assume that the
|
||||||
|
# weight and module use same name.
|
||||||
|
elif any(
|
||||||
|
weight_name.startswith(module)
|
||||||
|
for module in self.maybe_fused_weights_modules):
|
||||||
|
# special case for fused weights
|
||||||
|
# get the size of each shard weight tensor
|
||||||
|
total_shard_sizes = next(
|
||||||
|
(sizes for module, sizes in
|
||||||
|
self.maybe_fused_weights_modules.items()
|
||||||
|
if weight_name.startswith(module)))
|
||||||
|
total_size = weight_tensor.size(0)
|
||||||
|
assert total_size == sum(total_shard_sizes)
|
||||||
|
# get the start/end index of each shard weight tensor
|
||||||
|
total_start_index = list(
|
||||||
|
itertools.accumulate([0] + total_shard_sizes))[:-1]
|
||||||
|
shard_weights_index = [
|
||||||
|
(idx + size // tp_size * tp_rank,
|
||||||
|
idx + size // tp_size * (tp_rank + 1))
|
||||||
|
for idx, size in zip(total_start_index,
|
||||||
|
total_shard_sizes)
|
||||||
|
]
|
||||||
|
# slice and reorder the weight tensor
|
||||||
|
weight_tensor = [
|
||||||
|
weight_tensor[start_index:end_index, ...]
|
||||||
|
for start_index, end_index in shard_weights_index
|
||||||
|
]
|
||||||
|
weight_sub_tensor = torch.cat(weight_tensor, dim=0)
|
||||||
# Shard by row
|
# Shard by row
|
||||||
else:
|
else:
|
||||||
total_size = weight_tensor.size(0)
|
total_size = weight_tensor.size(0)
|
||||||
@ -985,12 +1016,22 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
else:
|
else:
|
||||||
self.target_modules = self.default_target_modules
|
self.target_modules = self.default_target_modules
|
||||||
|
|
||||||
|
# Modules whose weights might have fused on disk
|
||||||
|
# we need their output_sizes to make shard in flight correctly with TP
|
||||||
|
self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
|
||||||
|
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
# Some modules like `ReplicatedLinear` should not have their weights
|
# Some modules like `ReplicatedLinear` should not have their weights
|
||||||
# sharded. The reason for implementing it this way is to avoid new
|
# sharded. The reason for implementing it this way is to avoid new
|
||||||
# static variable in the model implementation.
|
# static variable in the model implementation.
|
||||||
if isinstance(module, (ReplicatedLinear, )):
|
if isinstance(module, (ReplicatedLinear, )):
|
||||||
self.unsharded_weights_modules.append(name)
|
self.unsharded_weights_modules.append(name)
|
||||||
|
# `QKVParallelLinear` and `MergedColumnParallelLinear` might have
|
||||||
|
# fused weights on disk. We need to use the output sizes of these
|
||||||
|
# modules to shard the weights correctly.
|
||||||
|
elif isinstance(module,
|
||||||
|
(QKVParallelLinear, MergedColumnParallelLinear)):
|
||||||
|
self.maybe_fused_weights_modules[name] = module.output_sizes
|
||||||
# In TP, these weights are partitioned along the column
|
# In TP, these weights are partitioned along the column
|
||||||
# dimension (dim=-1)
|
# dimension (dim=-1)
|
||||||
elif isinstance(module, (RowParallelLinear, )):
|
elif isinstance(module, (RowParallelLinear, )):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user