[Bugfix] Add fully sharded layer for QKVParallelLinearWithLora (#5665)

Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
Jee Li 2024-06-21 12:46:28 +08:00 committed by GitHub
parent c35e4a3dd7
commit 67005a07bc
5 changed files with 93 additions and 26 deletions
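
In short, this change lets `fully_sharded_loras=True` cover models whose attention uses a single packed `QKVParallelLinearWithLora` projection (such as Baichuan), not only the merged Q/K/V variant. A minimal usage sketch distilled from the test changes below; the model id, prompt, and adapter path are placeholders rather than part of this commit:

```python
# Sketch only: mirrors the LLM kwargs exercised by the updated Baichuan test.
# The model id, adapter path, and prompt are placeholders, not from the diff.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    "baichuan-inc/Baichuan-7B",       # placeholder model id
    enable_lora=True,
    max_loras=4,
    max_lora_rank=64,
    tensor_parallel_size=2,           # needs 2 GPUs; TP is what triggers sharding
    trust_remote_code=True,
    fully_sharded_loras=True,         # exercises the new sharded QKV LoRA layer
)
outputs = llm.generate(
    ["Hello, my name is"],
    SamplingParams(temperature=0, max_tokens=32),
    lora_request=LoRARequest("adapter", 1, "/path/to/lora"),  # placeholder path
)
```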

tests/lora/test_baichuan.py

@@ -64,7 +64,8 @@ def test_baichuan_lora(baichuan_lora_files):
@pytest.mark.skip("Requires multiple GPUs")
def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_baichuan_tensor_parallel_equality(baichuan_lora_files, fully_sharded):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < 4:
# pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
@@ -75,7 +76,8 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=1,
trust_remote_code=True)
trust_remote_code=True,
fully_sharded_loras=fully_sharded)
output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1)
del llm_tp1
@@ -87,7 +89,8 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=2,
trust_remote_code=True)
trust_remote_code=True,
fully_sharded_loras=fully_sharded)
output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2)
del llm_tp2
@@ -101,10 +104,11 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=4,
trust_remote_code=True)
trust_remote_code=True,
fully_sharded_loras=fully_sharded)
output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2)
del llm_tp4
cleanup()
assert output_tp1 == output_tp4

tests/lora/test_layers.py

@@ -12,7 +12,8 @@ from vllm.config import LoRAConfig
from vllm.lora.fully_sharded_layers import (
ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora,
RowParallelLinearWithShardedLoRA)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
@@ -684,7 +685,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
bias=False,
params_dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = QKVParallelLinearWithLora(linear)
lora_linear = QKVParallelLinearWithLora(
linear
) if not fully_shard else QKVParallelLinearWithShardedLora(linear)
@dataclass
class FakeConfig:

vllm/lora/fully_sharded_layers.py

@@ -12,6 +12,7 @@ from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora,
QKVParallelLinearWithLora,
RowParallelLinearWithLoRA)
from vllm.lora.punica import bgmv, dispatch_bgmv_low_level
@@ -90,11 +91,11 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
def _mcp_apply(x, bias, layer):
"""
MergedColumnParallelLinearWithShardedLoRA and
QKVParallelLinearWithShardedLora share the same
MergedQKVParallelLinearWithShardedLora share the same
LoRA weight application method.
The main difference is the step by shard_size for lora_b which can
vary for QKVParallelLinearWithShardedLora but is constant for
vary for MergedQKVParallelLinearWithShardedLora but is constant for
MergedColumnParallelLinearWithShardedLoRA.
"""
# expecting 2 for column parallel and 3 for qkv
@@ -167,7 +168,7 @@ class MergedColumnParallelLinearWithShardedLoRA(
)
class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
"""
Differs from QKVParallelLinearWithLora by slicing the
LoRA A's also.
@@ -175,6 +176,57 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
Based on S-LoRA, slicing happens along the rank dim.
"""
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked.shape[2]
start_idx = tp_rank * shard_size
lora_a = lora_a[:, start_idx:start_idx + shard_size]
return lora_a
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1,
output.shape[-1]), output.shape
buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
dtype=torch.float32,
device=x.device)
bgmv(buffer, x, self.lora_a_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
buffer = tensor_model_parallel_all_gather(buffer)
bgmv(output, buffer, self.lora_b_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
# now have column partitioned output
output = output.view(*out_orig_shape)
return output
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
"""
Differs from MergedQKVParallelLinearWithLora by slicing the
LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim.
"""
def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
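
To make the new sharding concrete, here is a standalone sketch (plain single-process PyTorch with invented shapes, not vLLM code) of the idea behind slice_lora_a plus the all-gather in apply above: slicing LoRA A along the rank dimension per tensor-parallel rank and gathering the partial projections reproduces the unsharded LoRA-A projection.

```python
# Standalone sketch of the S-LoRA-style sharding used above (not vLLM code):
# each TP rank keeps a slice of LoRA A along the rank dim, computes a partial
# x @ A_shard, and an all-gather recovers the full rank-sized intermediate.
# tp_size and all shapes below are invented for illustration.
import torch

tp_size, lora_rank, hidden = 2, 8, 16
x = torch.randn(4, hidden)               # flattened token activations
lora_a = torch.randn(hidden, lora_rank)  # full (unsharded) LoRA A

# Reference: project straight to the full LoRA rank.
ref = x @ lora_a

# Sharded: per-rank slice along the rank dim (mirrors slice_lora_a above),
# then concatenate the partial results (stands in for the all-gather).
shard = lora_rank // tp_size
partials = [x @ lora_a[:, r * shard:(r + 1) * shard] for r in range(tp_size)]
gathered = torch.cat(partials, dim=-1)

assert torch.allclose(ref, gathered)
```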

vllm/lora/layers.py

@@ -641,6 +641,24 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
self.base_layer.head_size)
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
self.q_shard_id = tp_rank
self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
lora_b_q = lora_b[:, self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)]
k_offset = self.q_proj_total_size
lora_b_k = lora_b[:, k_offset +
self.kv_proj_shard_size * self.kv_shard_id:k_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
v_offset = k_offset + self.kv_proj_total_size
lora_b_v = lora_b[:, v_offset +
self.kv_proj_shard_size * self.kv_shard_id:v_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
return lora_b
def set_lora(
self,
index: int,
@@ -650,21 +668,8 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
):
self.reset_lora(index)
if self.tp_size > 1:
tp_rank = get_tensor_model_parallel_rank()
self.q_shard_id = tp_rank
self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
lora_b_q = lora_b[:, self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)]
k_offset = self.q_proj_total_size
lora_b_k = lora_b[:, k_offset + self.kv_proj_shard_size *
self.kv_shard_id:k_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
v_offset = k_offset + self.kv_proj_total_size
lora_b_v = lora_b[:, v_offset + self.kv_proj_shard_size *
self.kv_shard_id:v_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
@@ -674,6 +679,7 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
lora_b.T, non_blocking=True)
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
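
The index arithmetic in the new slice_lora_b is easier to follow with concrete numbers. A worked example on a full LoRA B whose columns are laid out as [q | k | v]; head counts, shard sizes, and tp_rank are invented for illustration, not taken from the diff.

```python
# Worked example of the column arithmetic in slice_lora_b above (not vLLM
# code); shard sizes, tp_rank, and replica count are invented.
import torch

lora_rank = 8
q_proj_shard_size, q_proj_total_size = 8, 32   # per-rank / full q width
kv_proj_shard_size, kv_proj_total_size = 4, 8  # per-rank / full k (and v) width
tp_rank, num_kv_head_replicas = 3, 2

# Full LoRA B, columns laid out as [q | k | v].
lora_b = torch.randn(lora_rank, q_proj_total_size + 2 * kv_proj_total_size)

q_shard_id = tp_rank                            # -> 3
kv_shard_id = tp_rank // num_kv_head_replicas   # -> 1 (kv heads are replicated)

lora_b_q = lora_b[:, q_proj_shard_size * q_shard_id:
                  q_proj_shard_size * (q_shard_id + 1)]                # cols 24:32
k_offset = q_proj_total_size                                           # 32
lora_b_k = lora_b[:, k_offset + kv_proj_shard_size * kv_shard_id:
                  k_offset + kv_proj_shard_size * (kv_shard_id + 1)]   # cols 36:40
v_offset = k_offset + kv_proj_total_size                               # 40
lora_b_v = lora_b[:, v_offset + kv_proj_shard_size * kv_shard_id:
                  v_offset + kv_proj_shard_size * (kv_shard_id + 1)]   # cols 44:48

sliced = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
assert sliced.shape == (lora_rank, q_proj_shard_size + 2 * kv_proj_shard_size)
```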

vllm/lora/utils.py

@@ -8,7 +8,8 @@ from vllm.logger import init_logger
from vllm.lora.fully_sharded_layers import (
ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora,
RowParallelLinearWithShardedLoRA)
# being imported for _all_lora_classes below
# yapf conflicts with isort for this block
# yapf: disable
@@ -35,6 +36,7 @@ _all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
RowParallelLinearWithLoRA,
LogitsProcessorWithLoRA,
ColumnParallelLinearWithShardedLoRA,
QKVParallelLinearWithShardedLora,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora,
RowParallelLinearWithShardedLoRA,
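
For context, registering QKVParallelLinearWithShardedLora in `_all_lora_classes` is what makes it selectable: layer replacement asks each registered class whether it can replace a module, and the `_fully_sharded_can_replace` / `_not_fully_sharded_can_replace` decorators seen above gate the answer on `lora_config.fully_sharded_loras`, so at most one variant claims a given layer. A simplified paraphrase of that dispatch, not the literal vLLM implementation:

```python
# Simplified paraphrase of the replacement dispatch driven by _all_lora_classes
# (illustrative only, not the literal vLLM code).
from vllm.lora.utils import _all_lora_classes


def from_layer_sketch(layer, lora_config, packed_modules_list, model_config):
    for lora_cls in _all_lora_classes:
        # The *WithSharded* classes answer True only when
        # lora_config.fully_sharded_loras is set; their unsharded
        # counterparts answer True only when it is not.
        if lora_cls.can_replace_layer(source_layer=layer,
                                      lora_config=lora_config,
                                      packed_modules_list=packed_modules_list,
                                      model_config=model_config):
            return lora_cls(layer)
    return layer  # no LoRA class applies; keep the original module
```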