[Kernel][LoRA] Add assertion for punica sgmv kernels (#7585)
parent 86e9c8df29
commit 9b0e3ec970
@@ -169,6 +169,7 @@ def test_punica_sgmv(
         device,
     )
     max_seq_length = seq_len_tensor.max()
+    token_nums = seq_len_tensor.sum().item()
     if isinstance(max_seq_length, tuple):
         max_seq_length = max_seq_length[0].item()
     else:
@@ -183,6 +184,7 @@ def test_punica_sgmv(
             lora_indices_tensor,
             batches,
             max_seq_length,
+            token_nums,
             scaling,
         )
     else:
@@ -195,6 +197,7 @@ def test_punica_sgmv(
             lora_indices_tensor,
             batches,
             max_seq_length,
+            token_nums,
             add_inputs=True,
         )
     ref_torch_groupgemm(
@@ -347,6 +350,7 @@ def test_punica_expand_nslices(
         device,
     )
     max_seq_length = seq_len_tensor.max()
+    token_nums = seq_len_tensor.sum().item()
     if isinstance(max_seq_length, tuple):
         max_seq_length = max_seq_length[0].item()
     else:
@@ -364,6 +368,7 @@ def test_punica_expand_nslices(
                 lora_indices_tensor,
                 batches,
                 max_seq_length,
+                token_nums,
                 slice_offset,
                 hidden_size,
                 add_inputs=True,
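Note on the test changes above: the only new piece of metadata is token_nums, the total token count of the batch, which the tests now thread through to the kernels alongside max_seq_length. A minimal sketch of how these values relate (illustrative shapes and values only, not the actual test fixtures):

import torch

# Hypothetical per-request sequence lengths for a batch of three prompts.
seq_len_tensor = torch.tensor([4, 6, 5])

# max_seq_length bounds the per-sequence work; .max() returns a 0-dim tensor,
# hence the .item() handling seen in the test.
max_seq_length = seq_len_tensor.max().item()   # 6

# token_nums is the total number of tokens in the batch; the sgmv kernels
# now assert that inputs.size(0) equals this value.
token_nums = seq_len_tensor.sum().item()       # 15

inputs_tensor = torch.randn(token_nums, 16, dtype=torch.float16)
assert inputs_tensor.size(0) == token_nums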
@@ -84,6 +84,7 @@ def test_punica_sgmv(
         device,
     )
     max_seq_length = seq_len_tensor.max()
+    token_nums = seq_len_tensor.sum().item()
     if isinstance(max_seq_length, tuple):
         max_seq_length = max_seq_length[0].item()
     else:
@@ -98,6 +99,7 @@ def test_punica_sgmv(
             lora_indices_tensor,
             batches,
             max_seq_length,
+            token_nums,
             scaling,
         )
     else:
@@ -110,6 +112,7 @@ def test_punica_sgmv(
             lora_indices_tensor,
             batches,
             max_seq_length,
+            token_nums,
             add_inputs=True,
         )
     ref_torch_groupgemm(
@@ -262,6 +265,7 @@ def test_punica_expand_nslices(
         device,
     )
     max_seq_length = seq_len_tensor.max()
+    token_nums = seq_len_tensor.sum().item()
     if isinstance(max_seq_length, tuple):
         max_seq_length = max_seq_length[0].item()
     else:
@@ -279,6 +283,7 @@ def test_punica_expand_nslices(
                 lora_indices_tensor,
                 batches,
                 max_seq_length,
+                token_nums,
                 slice_offset,
                 hidden_size,
                 add_inputs=True,
@@ -100,7 +100,7 @@ def _bgmv_expand(
             corresponding to each batch, An index of -1 means no lora should be
             applied.
         batches (int): batch size
-        add_inputs (bool, optional): Defaults to False. adds the final lora
+        add_inputs (bool, optional): Defaults to False, adds the final lora
             results to the output.
     """
     assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
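The reworded add_inputs line describes accumulation semantics: with add_inputs=True the LoRA result is added onto whatever is already in the output buffer instead of overwriting it. A toy sketch of that behavior (a plain matmul stand-in, not the Triton kernel; all shapes are made up):

import torch

x = torch.randn(8, 16)          # per-token activations after the shrink step
lora_b = torch.randn(16, 64)    # one LoRA B matrix, illustrative shape
output = torch.randn(8, 64)     # pre-existing base-model output

lora_out = x @ lora_b
add_inputs = True
if add_inputs:
    output += lora_out          # accumulate, as the docstring describes
else:
    output = lora_out           # overwrite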
@@ -104,7 +104,7 @@ def _bgmv_expand_slice(
         lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
             corresponding to each batch, An index of -1 means no lora should be
             applied.
-        slice_offst (int): output_tensor's offst
+        slice_offset (int): output_tensor's offset
         slice_size (int): current output_tensor's size
         batches (int): batch size
         add_inputs (bool, optional): Defaults to False.
@@ -106,6 +106,7 @@ def _sgmv_expand(
     lora_indices_tensor: torch.Tensor,
     batches: int,
     max_seq_length: int,
+    token_nums: int,
     add_inputs: bool = False,
 ) -> None:
     """
@@ -117,15 +118,17 @@ def _sgmv_expand(
             sequence lengths of the sequences in the batch, used to index
             into sequence. E.g., if the sequence length is [4, 6], it is
             [0, 4, 10].
-        seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
-            length of the sequences in the batch
+        seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
+            length of the sequences in the batch.
         lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
             corresponding to each batch. An index of -1 means no lora should be
             applied.
         batches (int): batch size
-        max_seq_length (int): The max sequence lengths of the sequences
-            in the batch
-        add_inputs (bool, optional): Defaults to False. adds the final lora
+        max_seq_length (int): The max sequence lengths of the sequences in the
+            batch.
+        token_nums (int): The token numbers in the batch. Used to verify if the
+            token numbers in the inputs matches the one in the metadata.
+        add_inputs (bool, optional): Defaults to False, adds the final lora
             results to the output.
     """

@@ -134,6 +137,7 @@ def _sgmv_expand(
         torch.float16,
         torch.bfloat16,
     ]
+    assert inputs.size(0) == token_nums
     assert inputs.size(1) == lora_b_weights.size(-1)
     assert b_seq_start_loc.size(0) == batches
     assert lora_indices_tensor.size(0) == batches
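The new token_nums argument and its assertion tie the kernel input to the metadata computed for the batch: inputs is laid out as one row per token, so its first dimension must equal the recorded token total. A small illustration of the mismatch the assert is meant to catch (hypothetical shapes):

import torch

seq_len_tensor = torch.tensor([4, 6])
batches = seq_len_tensor.size(0)            # 2
token_nums = seq_len_tensor.sum().item()    # 10

inputs = torch.randn(10, 16, dtype=torch.float16)
stale_inputs = torch.randn(12, 16, dtype=torch.float16)

assert inputs.size(0) == token_nums         # consistent with the metadata
assert stale_inputs.size(0) != token_nums   # the kind of drift the new check flags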
@@ -112,6 +112,7 @@ def _sgmv_expand_slice(
     lora_indices_tensor: torch.Tensor,
     batches: int,
     max_seq_length: int,
+    token_nums: int,
     slice_offset: int,
     slice_size: int,
     add_inputs: bool = False,
@@ -126,7 +127,7 @@ def _sgmv_expand_slice(
             sequence lengths of the sequences in the batch, used to index
             into sequence. E.g., if the sequence length is [4, 6], it is
             [0, 4, 10].
-        seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
+        seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
             length of the sequences in the batch
         lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
             corresponding to each batch. An index of -1 means no lora should be
@@ -134,10 +135,12 @@ def _sgmv_expand_slice(
         batches (int): batch size
         max_seq_length (int): The max sequence lengths of the sequences
             in the batch
-        slice_offst (int): output_tensor's offst
+        token_nums (int): The token numbers in the batch. Used to verify if the
+            token numbers in the inputs matches the one in the metadata.
+        slice_offset (int): output_tensor's offset
         slice_size (int): current output_tensor's size
-        add_inputs (bool, optional): Defaults to False. adds the final lora
-            results to the output..
+        add_inputs (bool, optional): Defaults to False, adds the final lora
+            results to the output.
     """

     assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
@@ -145,6 +148,7 @@ def _sgmv_expand_slice(
         torch.float16,
         torch.bfloat16,
     ]
+    assert inputs.size(0) == token_nums
     assert inputs.size(1) == lora_b_weights.size(-1)
     assert b_seq_start_loc.size(0) == batches
     assert lora_indices_tensor.size(0) == batches
@@ -110,6 +110,7 @@ def _sgmv_shrink(
     lora_indices_tensor: torch.Tensor,
     batches: int,
     max_seq_length: int,
+    token_nums: int,
     scaling: float,
 ) -> None:
     """
@@ -122,14 +123,16 @@ def _sgmv_shrink(
             sequence lengths of the sequences in the batch, used to index
             into sequence. E.g., if the sequence length is [4, 6], it is
             [0, 4].
-        seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
-            length of the sequences in the batch
+        seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
+            length of the sequences in the batch.
         lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
             corresponding to each batch. An index of -1 means no lora should be
             applied.
         batches (int): batch size
-        max_seq_length (int): The max sequence lengths of the sequences
-            in the batch
+        max_seq_length (int): The max sequence lengths of the sequences in the
+            batch.
+        token_nums (int): The token numbers in the batch. Used to verify if the
+            token numbers in the inputs matches the one in the metadata.
         scaling (float): Scaling factor.
     """
     assert inputs.dtype == lora_a_weights.dtype
@@ -138,6 +141,7 @@ def _sgmv_shrink(
         torch.float16,
         torch.bfloat16,
     ]
+    assert inputs.size(0) == token_nums
    assert inputs.size(1) == lora_a_weights.size(-1)
    assert b_seq_start_loc.size(0) == batches
    assert lora_indices_tensor.size(0) == batches
@@ -27,7 +27,7 @@ if TYPE_CHECKING:

 def compute_meta(
     token_lora_tensor: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, bool]:
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
     """
     Get the information required for the sgmv kernel. With the features:
     1. If consecutive requests in the batch use the same LoRA, this function
@@ -43,7 +43,7 @@ def compute_meta(
     b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
     b_seq_start_tensor[1:].copy_(cum_result[:-1])
     max_length = seq_length_tensor.max().item()
+    token_nums = seq_length_tensor.sum().item()
     batch_size = lora_indices_tensor.size(0)
     no_lora = False
     # -1 means no lora should be applied. Use `no_lora` to determine whether
@@ -52,7 +52,7 @@ def compute_meta(
     if batch_size == 1 and lora_indices_tensor == -1:
         no_lora = True
     return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
-            batch_size, max_length, no_lora)
+            batch_size, max_length, token_nums, no_lora)


 # TODO see if this can be vectorized
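Worked through on a toy input, the compute_meta logic shown above yields the following values. The grouping step is assumed to use torch.unique_consecutive(..., return_counts=True), which matches the variable names in the hunk; the remaining lines mirror the diff:

import torch

# Per-token LoRA ids: two consecutive requests on LoRA 0, one on LoRA 2.
token_lora_tensor = torch.tensor([0, 0, 0, 0, 0, 0, 0, 2, 2, 2])

lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(
    token_lora_tensor, return_counts=True)        # ids [0, 2], counts [7, 3]

cum_result = torch.cumsum(seq_length_tensor, dim=0)
b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
b_seq_start_tensor[1:].copy_(cum_result[:-1])     # start offsets [0, 7]

max_length = seq_length_tensor.max().item()       # 7
token_nums = seq_length_tensor.sum().item()       # 10 (the new return value)
batch_size = lora_indices_tensor.size(0)          # 2
# A batch of size 1 whose only index is -1 would set no_lora = True.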
@@ -216,6 +216,7 @@ class PunicaWrapper:
                                               dtype=torch.long,
                                               device=device)
         self.max_length: int = 0
+        self.token_nums: int = 0
         self.batch_size: int = -1
         self.is_prefill = False
         self.no_lora = False
@@ -276,13 +277,13 @@ class PunicaWrapper:
                 long_lora_offsets_tensor)
         else:
             self._long_lora_indices.zero_()

         self.indices_len[:] = indices_len

     def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:

         (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
-         batch_size, max_length, no_lora) = compute_meta(token_lora_tensor)
+         batch_size, max_length, token_nums,
+         no_lora) = compute_meta(token_lora_tensor)

         self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_(
             b_seq_start_tensor)
@@ -291,25 +292,28 @@ class PunicaWrapper:
             lora_indices_tensor)
         self.batch_size = batch_size
         self.max_length = max_length
+        self.token_nums = token_nums
         self.no_lora = no_lora

     @property
     def prefill_metadata(
-            self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
+        self
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
         """
         This property provides a convenient way to access the necessary
         metadata for prefill-related kernel computations.
-            1. seq_start_locs: Tensor of sequence start positions
-            2. seq_lengths: Tensor of sequence lengths
+            1. seq_start_locs: Tensor of sequence start positions.
+            2. seq_lengths: Tensor of sequence lengths.
             3. lora_indices_per_batch: Tensor of lora indices, and an index of
                 -1 means no lora should be applied.
-            4. batch_size: batch size after clustering identical lora indices
-            5. max_length: The maximum sequence length in the batch
+            4. batch_size: Batch size after clustering identical lora indices.
+            5. max_length: The maximum sequence length in the batch.
+            6. token_nums: The token numbers in the batch.
         """
         return (self._seq_start_locs[:self.batch_size],
                 self._seq_lengths[:self.batch_size],
                 self._lora_indices_per_batch[:self.batch_size],
-                self.batch_size, self.max_length)
+                self.batch_size, self.max_length, self.token_nums)

     @property
     def token_lora_indices(self) -> torch.Tensor:
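Callers that unpack prefill_metadata now receive six elements, in the order listed in the docstring. A self-contained sketch of the consuming side (the tuple values here are made up; only the ordering is taken from the property above):

from typing import Tuple
import torch

def describe_prefill_metadata(
    meta: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]
) -> None:
    # Unpack in the documented order; token_nums is the newly added element.
    (seq_start_locs, seq_lengths, lora_indices_per_batch, batch_size,
     max_length, token_nums) = meta
    print(f"batch_size={batch_size}, max_length={max_length}, "
          f"token_nums={token_nums}")

describe_prefill_metadata((
    torch.tensor([0, 4]),     # seq_start_locs
    torch.tensor([4, 6]),     # seq_lengths
    torch.tensor([0, 1]),     # lora_indices_per_batch
    2,                        # batch_size
    6,                        # max_length
    10,                       # token_nums
))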
@@ -324,7 +328,7 @@ class PunicaWrapper:
     def sampler_indices(self) -> torch.Tensor:
         """
         This property is used to access the lora indices specifically for
-        LogitsProcessorWithLoRA
+        LogitsProcessorWithLoRA.
         """
         sampler_indices_len = self.indices_len[1]
         return self._sampler_indices[:sampler_indices_len]
@@ -332,7 +336,7 @@ class PunicaWrapper:
     @property
     def sampler_indices_padded(self) -> torch.Tensor:
         """
-        This property provides access to padded sampler indices
+        This property provides access to padded sampler indices.
         """
         indices_padded_len = self.indices_len[2]
         return self._sampler_indices_padded[:indices_padded_len]
@@ -341,7 +345,7 @@ class PunicaWrapper:
     def embeddings_indices(self) -> torch.Tensor:
         """
         This property provides access to the indices used for lora embeddings,
-        specifically for VocabParallelEmbeddingWithLoRA
+        specifically for VocabParallelEmbeddingWithLoRA.
         """
         embeddings_indices_len = self.indices_len[3]
         return self._embeddings_indices[:, :embeddings_indices_len]
@@ -350,7 +354,7 @@ class PunicaWrapper:
     def long_lora_indices(self) -> torch.Tensor:
         """
         This property provides access to the indices used for long context
-        lora, specifically for LinearScalingRotaryEmbeddingWithLora
+        lora, specifically for LinearScalingRotaryEmbeddingWithLora.
         """
         long_lora_len = self.indices_len[4]
         return self._long_lora_indices[:long_lora_len]
@@ -524,7 +528,7 @@ class PunicaWrapper:
             scale (float): Scaling factor.
             y_offset (Optional[int], optional): Offset to apply to the starting
                 column of y.
-            y_slice_size (Optional[int], optional): Size of the y column slice..
+            y_slice_size (Optional[int], optional): Size of the y column slice.
             buffer (Optional[torch.Tensor], optional): Defaults to None.
         """
         y_org = y