[Bugfix] Add fully sharded layer for QKVParallelLinearWithLora (#5665)
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
This commit is contained in:
parent c35e4a3dd7
commit 67005a07bc
@@ -64,7 +64,8 @@ def test_baichuan_lora(baichuan_lora_files):
 @pytest.mark.skip("Requires multiple GPUs")
-def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
+@pytest.mark.parametrize("fully_sharded", [True, False])
+def test_baichuan_tensor_parallel_equality(baichuan_lora_files, fully_sharded):
     # Cannot use as it will initialize torch.cuda too early...
     # if torch.cuda.device_count() < 4:
     #     pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
@@ -75,7 +76,8 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
                        max_loras=4,
                        max_lora_rank=64,
                        tensor_parallel_size=1,
-                       trust_remote_code=True)
+                       trust_remote_code=True,
+                       fully_sharded_loras=fully_sharded)
     output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1)
 
     del llm_tp1
@@ -87,7 +89,8 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
                        max_loras=4,
                        max_lora_rank=64,
                        tensor_parallel_size=2,
-                       trust_remote_code=True)
+                       trust_remote_code=True,
+                       fully_sharded_loras=fully_sharded)
     output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2)
 
     del llm_tp2
@@ -101,10 +104,11 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
                        max_loras=4,
                        max_lora_rank=64,
                        tensor_parallel_size=4,
-                       trust_remote_code=True)
+                       trust_remote_code=True,
+                       fully_sharded_loras=fully_sharded)
     output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2)
 
     del llm_tp4
     cleanup()
 
     assert output_tp1 == output_tp4
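Note: the `fully_sharded_loras` value exercised by this test is the same flag a user sets when building the engine. Below is a minimal usage sketch, assuming the standard `vllm.LLM` entry point and `LoRARequest` API; the model name, adapter path, and the `enable_lora` argument are placeholders and not taken from this diff.

```python
# Hedged usage sketch: enabling fully sharded LoRA layers at engine build time.
# Model name and adapter path are placeholders.
import vllm
from vllm.lora.request import LoRARequest

llm = vllm.LLM(
    model="meta-llama/Llama-2-7b-hf",   # placeholder base model
    enable_lora=True,                   # turn on LoRA support
    max_loras=4,
    max_lora_rank=64,
    tensor_parallel_size=2,
    fully_sharded_loras=True,           # flag exercised by this PR's tests
)

out = llm.generate(
    ["Hello, my name is"],
    lora_request=LoRARequest("my-adapter", 1, "/path/to/lora/adapter"),
)
print(out[0].outputs[0].text)
```

With `tensor_parallel_size > 1`, this flag selects the `*WithShardedLora` layer classes registered later in this diff instead of their unsharded counterparts.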
@@ -12,7 +12,8 @@ from vllm.config import LoRAConfig
 from vllm.lora.fully_sharded_layers import (
     ColumnParallelLinearWithShardedLoRA,
     MergedColumnParallelLinearWithShardedLoRA,
-    MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
+    MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora,
+    RowParallelLinearWithShardedLoRA)
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
@@ -684,7 +685,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
                                bias=False,
                                params_dtype=torch.float16)
     linear.weight.data = torch.rand_like(linear.weight.data)
-    lora_linear = QKVParallelLinearWithLora(linear)
+    lora_linear = QKVParallelLinearWithLora(
+        linear
+    ) if not fully_shard else QKVParallelLinearWithShardedLora(linear)
 
     @dataclass
     class FakeConfig:
@@ -12,6 +12,7 @@ from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
 from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                               MergedColumnParallelLinearWithLoRA,
                               MergedQKVParallelLinearWithLora,
+                              QKVParallelLinearWithLora,
                               RowParallelLinearWithLoRA)
 from vllm.lora.punica import bgmv, dispatch_bgmv_low_level
 
@@ -90,11 +91,11 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
 def _mcp_apply(x, bias, layer):
     """
     MergedColumnParallelLinearWithShardedLoRA and
-    QKVParallelLinearWithShardedLora share the same
+    MergedQKVParallelLinearWithShardedLora share the same
     LoRa weight application method.
 
     The main difference is the step by shard_size for lora_b which can
-    vary for QKVParallelLinearWithShardedLora but is constant for
+    vary for MergedQKVParallelLinearWithShardedLora but is constant for
     MergedColumnParallelLinearWithShardedLoRA.
     """
     # expecting 2 for column parallel and 3 for qkv
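Note: the renamed docstring above describes the heart of the fully sharded scheme. Each tensor-parallel rank holds only a slice of LoRA A along the rank dimension; the partial `x @ A_shard` products are all-gathered and then multiplied by that rank's column shard of B, giving a column-partitioned output. A single-process sketch in plain PyTorch (the tensor-parallel size and all shapes are made up; `torch.cat` stands in for `tensor_model_parallel_all_gather`):

```python
import torch

torch.manual_seed(0)
tp_size, lora_rank = 4, 16           # made-up TP size and LoRA rank
x = torch.randn(8, 32)               # (tokens, hidden)
lora_a = torch.randn(32, lora_rank)  # full LoRA A
lora_b = torch.randn(lora_rank, 64)  # full LoRA B

# Reference: the unsharded LoRA delta.
ref = x @ lora_a @ lora_b

# "Shrink": each rank keeps lora_rank // tp_size columns of A (rank-dim slice)
# and computes its partial low-rank activation.
shard = lora_rank // tp_size
partials = [x @ lora_a[:, i * shard:(i + 1) * shard] for i in range(tp_size)]

# All-gather of the partial buffers (a concatenation in this single-process sketch).
buffer = torch.cat(partials, dim=1)

# "Expand": each rank applies its column shard of B, so the final output is
# column-partitioned across ranks ("now have column partitioned output").
b_shard = lora_b.shape[1] // tp_size
outs = [buffer @ lora_b[:, i * b_shard:(i + 1) * b_shard] for i in range(tp_size)]

assert torch.allclose(torch.cat(outs, dim=1), ref, atol=1e-4)
```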
@@ -167,7 +168,7 @@ class MergedColumnParallelLinearWithShardedLoRA(
         )
 
 
-class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
+class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
     """
     Differs from QKVParallelLinearWithLora by slicing the
     LoRA A's also.
@@ -175,6 +176,57 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
     Based on S-LoRA, slicing happens along the rank dim.
     """
 
+    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
+        tp_rank = get_tensor_model_parallel_rank()
+        shard_size = self.lora_a_stacked.shape[2]
+        start_idx = tp_rank * shard_size
+        lora_a = lora_a[:, start_idx:start_idx + shard_size]
+        return lora_a
+
+    def apply(self, x: torch.Tensor,
+              bias: Optional[torch.Tensor]) -> torch.Tensor:
+        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
+
+        x = x.view(-1, x.shape[-1])
+        output, out_orig_shape = output.view(-1,
+                                             output.shape[-1]), output.shape
+        buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
+                             dtype=torch.float32,
+                             device=x.device)
+
+        bgmv(buffer, x, self.lora_a_stacked,
+             self.indices[:self.indices_len[0]], 0, 1.0)
+        buffer = tensor_model_parallel_all_gather(buffer)
+        bgmv(output, buffer, self.lora_b_stacked,
+             self.indices[:self.indices_len[0]], 0, 1.0)
+        # now have column partitioned output
+
+        output = output.view(*out_orig_shape)
+        return output
+
+    @classmethod
+    @_fully_sharded_can_replace
+    def can_replace_layer(cls, source_layer: nn.Module,
+                          lora_config: LoRAConfig, packed_modules_list: List,
+                          model_config: Optional[PretrainedConfig]) -> bool:
+        # specifying kwargs so they can be easily accessed in decorator
+        return super().can_replace_layer(
+            source_layer=source_layer,
+            lora_config=lora_config,
+            packed_modules_list=packed_modules_list,
+            model_config=model_config,
+            decorate=False,
+        )
+
+
+class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
+    """
+    Differs from MergedQKVParallelLinearWithLora by slicing the
+    LoRA A's also.
+
+    Based on S-LoRA, slicing happens along the rank dim.
+    """
+
     def slice_lora_a(
             self, lora_a: List[Union[torch.Tensor, None]]
     ) -> List[Union[torch.Tensor, None]]:
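Note: `slice_lora_a` above reads the per-rank shard width straight from `lora_a_stacked.shape[2]`, so rank `tp_rank` keeps columns `[tp_rank * shard_size, (tp_rank + 1) * shard_size)` of the loaded A. A small sketch of that indexing and of the transposed copy performed by `set_lora`; the `(hidden, rank)` orientation of `lora_a`, the stacked layout, and all sizes are illustrative assumptions, not taken from this diff:

```python
import torch

# Illustrative sizes only: 4 LoRA slots, hidden=32, max LoRA rank 16 split
# across tp_size=4 ranks -> per-rank rank shard of 4.
tp_size, tp_rank = 4, 2
max_lora_rank, hidden = 16, 32
rank_shard = max_lora_rank // tp_size

# Assumed stacked layout (max_loras, 1, rank_shard, hidden); shape[2] is what
# slice_lora_a uses as shard_size.
lora_a_stacked = torch.zeros(4, 1, rank_shard, hidden)

# A freshly loaded LoRA A, assumed (hidden, max_lora_rank), so the columns are
# the rank dimension being sliced.
lora_a = torch.randn(hidden, max_lora_rank)

shard_size = lora_a_stacked.shape[2]
start_idx = tp_rank * shard_size
lora_a_shard = lora_a[:, start_idx:start_idx + shard_size]
assert lora_a_shard.shape == (hidden, rank_shard)   # rank 2 keeps columns 8..11

# set_lora then stores the transposed shard, mirroring the copy_ in the diff.
lora_a_stacked[0, 0, :lora_a_shard.shape[1], :lora_a_shard.shape[0]].copy_(
    lora_a_shard.T)
```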
@@ -641,6 +641,24 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
         self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
                                    self.base_layer.head_size)
 
+    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
+        tp_rank = get_tensor_model_parallel_rank()
+        self.q_shard_id = tp_rank
+        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
+        lora_b_q = lora_b[:, self.q_proj_shard_size *
+                          self.q_shard_id:self.q_proj_shard_size *
+                          (self.q_shard_id + 1)]
+        k_offset = self.q_proj_total_size
+        lora_b_k = lora_b[:, k_offset +
+                          self.kv_proj_shard_size * self.kv_shard_id:k_offset +
+                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
+        v_offset = k_offset + self.kv_proj_total_size
+        lora_b_v = lora_b[:, v_offset +
+                          self.kv_proj_shard_size * self.kv_shard_id:v_offset +
+                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
+        lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
+        return lora_b
+
     def set_lora(
         self,
         index: int,
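Note: the new `slice_lora_b` centralizes the Q/K/V column bookkeeping that `set_lora` previously did inline (removed in the next hunk). A worked example of the offsets with made-up sizes (two tensor-parallel ranks, one KV-head replica per rank), showing which columns of the packed B matrix rank 1 keeps:

```python
import torch

# Made-up projection sizes for illustration (not from the diff): per-rank shard
# sizes and full (total) sizes of the packed QKV output dimension.
q_proj_shard_size, kv_proj_shard_size = 8, 4      # per tensor-parallel rank
tp_size = 2
num_kv_head_replicas = 1
q_proj_total_size = q_proj_shard_size * tp_size   # 16
kv_proj_total_size = kv_proj_shard_size * tp_size # 8

rank_dim, tp_rank = 4, 1
lora_b = torch.arange(rank_dim * (q_proj_total_size + 2 * kv_proj_total_size),
                      dtype=torch.float32).reshape(rank_dim, -1)

q_shard_id = tp_rank
kv_shard_id = tp_rank // num_kv_head_replicas

# Same slicing as slice_lora_b: Q block, then K block after q_proj_total_size,
# then V block after another kv_proj_total_size.
lora_b_q = lora_b[:, q_proj_shard_size * q_shard_id:
                  q_proj_shard_size * (q_shard_id + 1)]
k_offset = q_proj_total_size
lora_b_k = lora_b[:, k_offset + kv_proj_shard_size * kv_shard_id:
                  k_offset + kv_proj_shard_size * (kv_shard_id + 1)]
v_offset = k_offset + kv_proj_total_size
lora_b_v = lora_b[:, v_offset + kv_proj_shard_size * kv_shard_id:
                  v_offset + kv_proj_shard_size * (kv_shard_id + 1)]

sliced = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
# Rank 1 keeps columns 8-15 (its Q shard), 20-23 (K shard), 28-31 (V shard).
assert sliced.shape == (rank_dim, q_proj_shard_size + 2 * kv_proj_shard_size)
```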
@@ -650,21 +668,8 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
     ):
         self.reset_lora(index)
         if self.tp_size > 1:
-            tp_rank = get_tensor_model_parallel_rank()
-            self.q_shard_id = tp_rank
-            self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
-            lora_b_q = lora_b[:, self.q_proj_shard_size *
-                              self.q_shard_id:self.q_proj_shard_size *
-                              (self.q_shard_id + 1)]
-            k_offset = self.q_proj_total_size
-            lora_b_k = lora_b[:, k_offset + self.kv_proj_shard_size *
-                              self.kv_shard_id:k_offset +
-                              self.kv_proj_shard_size * (self.kv_shard_id + 1)]
-            v_offset = k_offset + self.kv_proj_total_size
-            lora_b_v = lora_b[:, v_offset + self.kv_proj_shard_size *
-                              self.kv_shard_id:v_offset +
-                              self.kv_proj_shard_size * (self.kv_shard_id + 1)]
-            lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
+            lora_a = self.slice_lora_a(lora_a)
+            lora_b = self.slice_lora_b(lora_b)
 
         self.lora_a_stacked[index,
                             0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
@@ -674,6 +679,7 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
                                 lora_b.T, non_blocking=True)
 
     @classmethod
+    @_not_fully_sharded_can_replace
     def can_replace_layer(cls, source_layer: nn.Module,
                           lora_config: LoRAConfig, packed_modules_list: List,
                           model_config: Optional[PretrainedConfig]) -> bool:
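Note: together with the `@_fully_sharded_can_replace` used on the new sharded class, this added `@_not_fully_sharded_can_replace` makes the two `can_replace_layer` checks mutually exclusive on `lora_config.fully_sharded_loras`. The decorators themselves are not part of this diff; the following is only a rough sketch of the gating they are presumed to perform, with the `decorate` handling inferred from the `decorate=False` call in the sharded subclass:

```python
# Hedged sketch only: illustrates the presumed gating of can_replace_layer by
# the fully_sharded_loras flag; not the actual vLLM decorator implementations.
from typing import Callable


def _fully_sharded_can_replace(can_replace: Callable) -> Callable:
    """Only allow the sharded LoRA class when fully_sharded_loras is set."""

    def dec(*args, **kwargs):
        return (can_replace(*args, **kwargs)
                and kwargs["lora_config"].fully_sharded_loras)

    return dec


def _not_fully_sharded_can_replace(can_replace: Callable) -> Callable:
    """Only allow the unsharded LoRA class when fully_sharded_loras is unset.

    The decorate=False kwarg (as passed by the sharded subclass above) is
    presumed to bypass the check so the subclass can reuse the base test.
    """

    def dec(*args, **kwargs):
        decorate = kwargs.pop("decorate", True)
        condition = (not kwargs["lora_config"].fully_sharded_loras
                     if decorate else True)
        return can_replace(*args, **kwargs) and condition

    return dec
```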
@@ -8,7 +8,8 @@ from vllm.logger import init_logger
 from vllm.lora.fully_sharded_layers import (
     ColumnParallelLinearWithShardedLoRA,
     MergedColumnParallelLinearWithShardedLoRA,
-    MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
+    MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora,
+    RowParallelLinearWithShardedLoRA)
 # being imported for _all_lora_classes below
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -35,6 +36,7 @@ _all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
     RowParallelLinearWithLoRA,
     LogitsProcessorWithLoRA,
     ColumnParallelLinearWithShardedLoRA,
+    QKVParallelLinearWithShardedLora,
     MergedColumnParallelLinearWithShardedLoRA,
     MergedQKVParallelLinearWithShardedLora,
     RowParallelLinearWithShardedLoRA,