[Kernel][LoRA] Add assertion for punica sgmv kernels (#7585)

Jee Jee Li 2024-09-24 02:57:42 +08:00 committed by GitHub
parent 86e9c8df29
commit 9b0e3ec970
8 changed files with 64 additions and 38 deletions

View File

@@ -169,6 +169,7 @@ def test_punica_sgmv(
device,
)
max_seq_length = seq_len_tensor.max()
token_nums = seq_len_tensor.sum().item()
if isinstance(max_seq_length, tuple):
max_seq_length = max_seq_length[0].item()
else:
@@ -183,6 +184,7 @@ def test_punica_sgmv(
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
scaling,
)
else:
@@ -195,6 +197,7 @@ def test_punica_sgmv(
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
add_inputs=True,
)
ref_torch_groupgemm(
@@ -347,6 +350,7 @@ def test_punica_expand_nslices(
device,
)
max_seq_length = seq_len_tensor.max()
token_nums = seq_len_tensor.sum().item()
if isinstance(max_seq_length, tuple):
max_seq_length = max_seq_length[0].item()
else:
@@ -364,6 +368,7 @@ def test_punica_expand_nslices(
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
slice_offset,
hidden_size,
add_inputs=True,
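Both test hunks follow the same pattern: the per-sequence lengths already drive max_seq_length, and the commit adds token_nums (the summed length) so the kernels can cross-check the flattened token dimension. A minimal sketch of that metadata construction in plain PyTorch (names mirror the test; the sizes are illustrative):

import torch

# Illustrative batch of 4 sequences with random lengths, as the tests generate.
batches = 4
seq_len_tensor = torch.randint(1, 128, (batches,), dtype=torch.long)

# Start offset of each sequence in the flattened token dimension.
b_seq_start_loc = torch.zeros_like(seq_len_tensor)
b_seq_start_loc[1:] = torch.cumsum(seq_len_tensor, dim=0)[:-1]

max_seq_length = seq_len_tensor.max().item()  # grid bound for the kernel
token_nums = seq_len_tensor.sum().item()      # total tokens; the new kernel argument

# The flattened activations must carry exactly token_nums rows.
hidden_size = 128
inputs = torch.randn(token_nums, hidden_size)
assert inputs.size(0) == token_nums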

View File

@@ -84,6 +84,7 @@ def test_punica_sgmv(
device,
)
max_seq_length = seq_len_tensor.max()
token_nums = seq_len_tensor.sum().item()
if isinstance(max_seq_length, tuple):
max_seq_length = max_seq_length[0].item()
else:
@@ -98,6 +99,7 @@ def test_punica_sgmv(
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
scaling,
)
else:
@@ -110,6 +112,7 @@ def test_punica_sgmv(
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
add_inputs=True,
)
ref_torch_groupgemm(
@@ -262,6 +265,7 @@ def test_punica_expand_nslices(
device,
)
max_seq_length = seq_len_tensor.max()
token_nums = seq_len_tensor.sum().item()
if isinstance(max_seq_length, tuple):
max_seq_length = max_seq_length[0].item()
else:
@@ -279,6 +283,7 @@ def test_punica_expand_nslices(
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
slice_offset,
hidden_size,
add_inputs=True,

View File

@@ -100,7 +100,7 @@ def _bgmv_expand(
corresponding to each batch, An index of -1 means no lora should be
applied.
batches (int): batch size
add_inputs (bool, optional): Defaults to False. adds the final lora
add_inputs (bool, optional): Defaults to False, adds the final lora
results to the output.
"""
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]

View File

@@ -104,7 +104,7 @@ def _bgmv_expand_slice(
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch, An index of -1 means no lora should be
applied.
slice_offst (int): output_tensor's offst
slice_offset (int): output_tensor's offset
slice_size (int): current output_tensor's size
batches (int): batch size
add_inputs (bool, optional): Defaults to False.

View File

@@ -106,6 +106,7 @@ def _sgmv_expand(
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
add_inputs: bool = False,
) -> None:
"""
@@ -117,15 +118,17 @@ def _sgmv_expand(
sequence lengths of the sequences in the batch, used to index
into sequence. E.g., if the sequence length is [4, 6], it is
[0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
length of the sequences in the batch
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
in the batch
add_inputs (bool, optional): Defaults to False. adds the final lora
max_seq_length (int): The max sequence lengths of the sequences in the
batch.
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
add_inputs (bool, optional): Defaults to False, adds the final lora
results to the output.
"""
@@ -134,6 +137,7 @@ def _sgmv_expand(
torch.float16,
torch.bfloat16,
]
assert inputs.size(0) == token_nums
assert inputs.size(1) == lora_b_weights.size(-1)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
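Outside Triton, the new check in _sgmv_expand is just a shape-consistency test between the flattened activations and the prefill metadata. A standalone sketch of the validation, collecting the asserts from the hunk above into a hypothetical helper:

import torch

def check_sgmv_expand_args(inputs: torch.Tensor,
                           lora_b_weights: torch.Tensor,
                           b_seq_start_loc: torch.Tensor,
                           lora_indices_tensor: torch.Tensor,
                           batches: int,
                           token_nums: int) -> None:
    # Hypothetical helper; each assert mirrors one added or existing check.
    assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
    assert lora_b_weights.dtype in [torch.float16, torch.bfloat16]
    assert inputs.size(0) == token_nums  # new: token count must match the metadata
    assert inputs.size(1) == lora_b_weights.size(-1)
    assert b_seq_start_loc.size(0) == batches
    assert lora_indices_tensor.size(0) == batches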

View File

@@ -112,6 +112,7 @@ def _sgmv_expand_slice(
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
slice_offset: int,
slice_size: int,
add_inputs: bool = False,
@@ -126,7 +127,7 @@ def _sgmv_expand_slice(
sequence lengths of the sequences in the batch, used to index
into sequence. E.g., if the sequence length is [4, 6], it is
[0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
@@ -134,10 +135,12 @@ def _sgmv_expand_slice(
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
in the batch
slice_offst (int): output_tensor's offst
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
slice_offset (int): output_tensor's offset
slice_size (int): current output_tensor's size
add_inputs (bool, optional): Defaults to False. adds the final lora
results to the output..
add_inputs (bool, optional): Defaults to False, adds the final lora
results to the output.
"""
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
@@ -145,6 +148,7 @@ def _sgmv_expand_slice(
torch.float16,
torch.bfloat16,
]
assert inputs.size(0) == token_nums
assert inputs.size(1) == lora_b_weights.size(-1)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
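The slice variant performs the same expand but writes only the column window [slice_offset, slice_offset + slice_size) of the output. A rough pure-PyTorch picture of that write; the (num_loras, slice_size, rank) weight layout and the accumulation behaviour are assumptions for illustration, not the Triton kernel itself:

import torch

def sgmv_expand_slice_ref(inputs, lora_b_weights, output,
                          b_seq_start_loc, seq_len_tensor, lora_indices_tensor,
                          slice_offset, slice_size, add_inputs=False):
    # Per-group GEMM into a column slice of `output`.
    for b in range(seq_len_tensor.size(0)):
        idx = int(lora_indices_tensor[b])
        if idx == -1:  # -1 means no LoRA is applied to this group
            continue
        start = int(b_seq_start_loc[b])
        end = start + int(seq_len_tensor[b])
        # (seq_len, rank) @ (rank, slice_size) -> (seq_len, slice_size)
        update = inputs[start:end] @ lora_b_weights[idx].transpose(0, 1)
        cols = slice(slice_offset, slice_offset + slice_size)
        if add_inputs:
            output[start:end, cols] += update
        else:
            output[start:end, cols] = update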

View File

@@ -110,6 +110,7 @@ def _sgmv_shrink(
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
scaling: float,
) -> None:
"""
@@ -122,14 +123,16 @@ def _sgmv_shrink(
sequence lengths of the sequences in the batch, used to index
into sequence. E.g., if the sequence length is [4, 6], it is
[0, 4].
seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
length of the sequences in the batch
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
in the batch
max_seq_length (int): The max sequence lengths of the sequences in the
batch.
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
scaling (float): Scaling factor.
"""
assert inputs.dtype == lora_a_weights.dtype
@@ -138,6 +141,7 @@ def _sgmv_shrink(
torch.float16,
torch.bfloat16,
]
assert inputs.size(0) == token_nums
assert inputs.size(1) == lora_a_weights.size(-1)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
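The shrink direction is the mirror image: activations of width hidden_size are projected down to the LoRA rank and scaled. A rough reference of the computation, assuming lora_a_weights is laid out as (num_loras, rank, hidden_size), which is consistent with the inputs.size(1) == lora_a_weights.size(-1) check above but still only an illustrative sketch:

import torch

def sgmv_shrink_ref(inputs, lora_a_weights, output,
                    b_seq_start_loc, seq_len_tensor, lora_indices_tensor,
                    scaling):
    # The check this commit adds: flattened token count must match the metadata.
    assert inputs.size(0) == int(seq_len_tensor.sum())
    for b in range(seq_len_tensor.size(0)):
        idx = int(lora_indices_tensor[b])
        if idx == -1:
            continue
        start = int(b_seq_start_loc[b])
        end = start + int(seq_len_tensor[b])
        # (seq_len, hidden_size) @ (hidden_size, rank) -> (seq_len, rank)
        output[start:end] = scaling * (inputs[start:end]
                                       @ lora_a_weights[idx].transpose(0, 1))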

View File

@@ -27,7 +27,7 @@ if TYPE_CHECKING:
def compute_meta(
token_lora_tensor: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, bool]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
"""
Get the information required for the sgmv kernel. With the features:
1. If consecutive requests in the batch use the same LoRA, this function
@@ -43,7 +43,7 @@ def compute_meta(
b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
b_seq_start_tensor[1:].copy_(cum_result[:-1])
max_length = seq_length_tensor.max().item()
token_nums = seq_length_tensor.sum().item()
batch_size = lora_indices_tensor.size(0)
no_lora = False
# -1 means no lora should be applied. Use `no_lora` to determine whether
@@ -52,7 +52,7 @@ def compute_meta(
if batch_size == 1 and lora_indices_tensor == -1:
no_lora = True
return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
batch_size, max_length, no_lora)
batch_size, max_length, token_nums, no_lora)
# TODO see if this can be vectorized
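compute_meta now returns the summed token count alongside the per-group metadata. A sketch of the full computation; grouping consecutive tokens that share a LoRA index via torch.unique_consecutive is an assumption inferred from the docstring, while the rest mirrors the hunk above:

import torch

def compute_meta_sketch(token_lora_tensor: torch.Tensor):
    # Collapse consecutive tokens with the same LoRA index into one group
    # (assumed grouping strategy).
    lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(
        token_lora_tensor, return_counts=True)
    cum_result = torch.cumsum(seq_length_tensor, dim=0)
    b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
    b_seq_start_tensor[1:].copy_(cum_result[:-1])
    max_length = seq_length_tensor.max().item()
    token_nums = seq_length_tensor.sum().item()  # the new element of the tuple
    batch_size = lora_indices_tensor.size(0)
    # batch_size == 1 with index -1 means no LoRA is applied at all.
    no_lora = batch_size == 1 and bool(lora_indices_tensor == -1)
    return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
            batch_size, max_length, token_nums, no_lora)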
@@ -216,6 +216,7 @@ class PunicaWrapper:
dtype=torch.long,
device=device)
self.max_length: int = 0
self.token_nums: int = 0
self.batch_size: int = -1
self.is_prefill = False
self.no_lora = False
@@ -276,13 +277,13 @@ class PunicaWrapper:
long_lora_offsets_tensor)
else:
self._long_lora_indices.zero_()
self.indices_len[:] = indices_len
def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:
(b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
batch_size, max_length, no_lora) = compute_meta(token_lora_tensor)
batch_size, max_length, token_nums,
no_lora) = compute_meta(token_lora_tensor)
self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_(
b_seq_start_tensor)
@@ -291,25 +292,28 @@ class PunicaWrapper:
lora_indices_tensor)
self.batch_size = batch_size
self.max_length = max_length
self.token_nums = token_nums
self.no_lora = no_lora
@property
def prefill_metadata(
self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
self
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
"""
This property provides a convenient way to access the necessary
metadata for prefill-related kernel computations.
1. seq_start_locs: Tensor of sequence start positions
2. seq_lengths: Tensor of sequence lengths
1. seq_start_locs: Tensor of sequence start positions.
2. seq_lengths: Tensor of sequence lengths.
3. lora_indices_per_batch: Tensor of lora indices, and an index of
-1 means no lora should be applied.
4. batch_size: batch size after clustering identical lora indices
5. max_length: The maximum sequence length in the batch
4. batch_size: Batch size after clustering identical lora indices.
5. max_length: The maximum sequence length in the batch.
6. token_nums: The token numbers in the batch.
"""
return (self._seq_start_locs[:self.batch_size],
self._seq_lengths[:self.batch_size],
self._lora_indices_per_batch[:self.batch_size],
self.batch_size, self.max_length)
self.batch_size, self.max_length, self.token_nums)
@property
def token_lora_indices(self) -> torch.Tensor:
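Because prefill_metadata grew from a 5-tuple to a 6-tuple, every call site that unpacks it has to pick up the extra element. A hedged sketch of such a caller; the function name is hypothetical and the actual kernel invocation is elided:

def run_sgmv_prefill(wrapper, inputs):
    # `wrapper` is assumed to be a PunicaWrapper whose prefill metadata is set.
    (b_seq_start_loc, seq_len_tensor, lora_indices_tensor,
     batches, max_seq_length, token_nums) = wrapper.prefill_metadata
    # token_nums is threaded through to the sgmv kernels, which now assert
    # inputs.size(0) == token_nums before launching the Triton grid.
    ...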
@@ -324,7 +328,7 @@ class PunicaWrapper:
def sampler_indices(self) -> torch.Tensor:
"""
This property is used to access the lora indices specifically for
LogitsProcessorWithLoRA
LogitsProcessorWithLoRA.
"""
sampler_indices_len = self.indices_len[1]
return self._sampler_indices[:sampler_indices_len]
@@ -332,7 +336,7 @@ class PunicaWrapper:
@property
def sampler_indices_padded(self) -> torch.Tensor:
"""
This property provides access to padded sampler indices
This property provides access to padded sampler indices.
"""
indices_padded_len = self.indices_len[2]
return self._sampler_indices_padded[:indices_padded_len]
@@ -341,7 +345,7 @@ class PunicaWrapper:
def embeddings_indices(self) -> torch.Tensor:
"""
This property provides access to the indices used for lora embeddings,
specifically for VocabParallelEmbeddingWithLoRA
specifically for VocabParallelEmbeddingWithLoRA.
"""
embeddings_indices_len = self.indices_len[3]
return self._embeddings_indices[:, :embeddings_indices_len]
@@ -350,7 +354,7 @@ class PunicaWrapper:
def long_lora_indices(self) -> torch.Tensor:
"""
This property provides access to the indices used for long context
lora, specifically for LinearScalingRotaryEmbeddingWithLora
lora, specifically for LinearScalingRotaryEmbeddingWithLora.
"""
long_lora_len = self.indices_len[4]
return self._long_lora_indices[:long_lora_len]
@@ -524,7 +528,7 @@ class PunicaWrapper:
scale (float): Scaling factor.
y_offset (Optional[int], optional): Offset to apply to the starting
column of y.
y_slice_size (Optional[int], optional): Size of the y column slice..
y_slice_size (Optional[int], optional): Size of the y column slice.
buffer (Optional[torch.Tensor], optional): Defaults to None.
"""
y_org = y