[Kernel][LoRA] Add assertion for punica sgmv kernels (#7585)

Jee Jee Li 2024-09-24 02:57:42 +08:00 committed by GitHub
parent 86e9c8df29
commit 9b0e3ec970
8 changed files with 64 additions and 38 deletions
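The change threads a new token_nums value, the total number of tokens across the batch (seq_len_tensor.sum()), through the punica sgmv kernels so each kernel can assert that the flattened input really has that many rows. Below is a minimal standalone sketch of the invariant being checked; the names and shapes (seq_len_tensor, inputs, hidden_size) are illustrative stand-ins, not the actual kernel call.

import torch

# Hypothetical per-request prompt lengths for a batch of three LoRA requests.
seq_len_tensor = torch.tensor([4, 6, 2])

# Metadata handed to the sgmv kernels alongside the flattened inputs.
token_nums = seq_len_tensor.sum().item()      # 12 tokens in total
max_seq_length = seq_len_tensor.max().item()  # longest request: 6 tokens

# Kernel inputs are flattened to (num_tokens, hidden_size); the new assertion
# checks that this first dimension matches the metadata.
hidden_size = 16
inputs = torch.randn(token_nums, hidden_size, dtype=torch.float16)
assert inputs.size(0) == token_nums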

View File

@@ -169,6 +169,7 @@ def test_punica_sgmv(
         device,
     )
     max_seq_length = seq_len_tensor.max()
+    token_nums = seq_len_tensor.sum().item()
     if isinstance(max_seq_length, tuple):
         max_seq_length = max_seq_length[0].item()
     else:
@@ -183,6 +184,7 @@ def test_punica_sgmv(
             lora_indices_tensor,
             batches,
             max_seq_length,
+            token_nums,
             scaling,
         )
     else:
@@ -195,6 +197,7 @@ def test_punica_sgmv(
             lora_indices_tensor,
             batches,
             max_seq_length,
+            token_nums,
             add_inputs=True,
         )
     ref_torch_groupgemm(
@@ -347,6 +350,7 @@ def test_punica_expand_nslices(
         device,
     )
     max_seq_length = seq_len_tensor.max()
+    token_nums = seq_len_tensor.sum().item()
     if isinstance(max_seq_length, tuple):
         max_seq_length = max_seq_length[0].item()
     else:
@@ -364,6 +368,7 @@ def test_punica_expand_nslices(
             lora_indices_tensor,
             batches,
             max_seq_length,
+            token_nums,
             slice_offset,
             hidden_size,
             add_inputs=True,

View File

@@ -84,6 +84,7 @@ def test_punica_sgmv(
         device,
     )
     max_seq_length = seq_len_tensor.max()
+    token_nums = seq_len_tensor.sum().item()
     if isinstance(max_seq_length, tuple):
         max_seq_length = max_seq_length[0].item()
     else:
@@ -98,6 +99,7 @@ def test_punica_sgmv(
             lora_indices_tensor,
             batches,
             max_seq_length,
+            token_nums,
             scaling,
         )
     else:
@@ -110,6 +112,7 @@ def test_punica_sgmv(
             lora_indices_tensor,
             batches,
             max_seq_length,
+            token_nums,
             add_inputs=True,
         )
     ref_torch_groupgemm(
@@ -262,6 +265,7 @@ def test_punica_expand_nslices(
         device,
     )
     max_seq_length = seq_len_tensor.max()
+    token_nums = seq_len_tensor.sum().item()
     if isinstance(max_seq_length, tuple):
         max_seq_length = max_seq_length[0].item()
     else:
@@ -279,6 +283,7 @@ def test_punica_expand_nslices(
             lora_indices_tensor,
             batches,
             max_seq_length,
+            token_nums,
             slice_offset,
             hidden_size,
             add_inputs=True,

View File

@@ -100,7 +100,7 @@ def _bgmv_expand(
             corresponding to each batch, An index of -1 means no lora should be
             applied.
         batches (int): batch size
-        add_inputs (bool, optional): Defaults to False. adds the final lora
+        add_inputs (bool, optional): Defaults to False, adds the final lora
             results to the output.
     """
     assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]

View File

@@ -104,7 +104,7 @@ def _bgmv_expand_slice(
         lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
             corresponding to each batch, An index of -1 means no lora should be
             applied.
-        slice_offst (int): output_tensor's offst
+        slice_offset (int): output_tensor's offset
         slice_size (int): current output_tensor's size
         batches (int): batch size
         add_inputs (bool, optional): Defaults to False.

View File

@@ -106,6 +106,7 @@ def _sgmv_expand(
     lora_indices_tensor: torch.Tensor,
     batches: int,
     max_seq_length: int,
+    token_nums: int,
     add_inputs: bool = False,
 ) -> None:
     """
@@ -117,15 +118,17 @@ def _sgmv_expand(
             sequence lengths of the sequences in the batch, used to index
             into sequence. E.g., if the sequence length is [4, 6], it is
             [0, 4, 10].
-        seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
-            length of the sequences in the batch
+        seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
+            length of the sequences in the batch.
         lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
             corresponding to each batch. An index of -1 means no lora should be
             applied.
         batches (int): batch size
-        max_seq_length (int): The max sequence lengths of the sequences
-            in the batch
-        add_inputs (bool, optional): Defaults to False. adds the final lora
+        max_seq_length (int): The max sequence lengths of the sequences in the
+            batch.
+        token_nums (int): The token numbers in the batch. Used to verify if the
+            token numbers in the inputs matches the one in the metadata.
+        add_inputs (bool, optional): Defaults to False, adds the final lora
             results to the output.
     """
@@ -134,6 +137,7 @@ def _sgmv_expand(
         torch.float16,
         torch.bfloat16,
     ]
+    assert inputs.size(0) == token_nums
     assert inputs.size(1) == lora_b_weights.size(-1)
     assert b_seq_start_loc.size(0) == batches
     assert lora_indices_tensor.size(0) == batches

View File

@@ -112,6 +112,7 @@ def _sgmv_expand_slice(
     lora_indices_tensor: torch.Tensor,
     batches: int,
     max_seq_length: int,
+    token_nums: int,
     slice_offset: int,
     slice_size: int,
     add_inputs: bool = False,
@@ -126,7 +127,7 @@ def _sgmv_expand_slice(
             sequence lengths of the sequences in the batch, used to index
             into sequence. E.g., if the sequence length is [4, 6], it is
             [0, 4, 10].
-        seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
+        seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
             length of the sequences in the batch
         lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
             corresponding to each batch. An index of -1 means no lora should be
@@ -134,10 +135,12 @@ def _sgmv_expand_slice(
         batches (int): batch size
         max_seq_length (int): The max sequence lengths of the sequences
             in the batch
-        slice_offst (int): output_tensor's offst
+        token_nums (int): The token numbers in the batch. Used to verify if the
+            token numbers in the inputs matches the one in the metadata.
+        slice_offset (int): output_tensor's offset
         slice_size (int): current output_tensor's size
-        add_inputs (bool, optional): Defaults to False. adds the final lora
-            results to the output..
+        add_inputs (bool, optional): Defaults to False, adds the final lora
+            results to the output.
     """
     assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
@@ -145,6 +148,7 @@ def _sgmv_expand_slice(
         torch.float16,
         torch.bfloat16,
     ]
+    assert inputs.size(0) == token_nums
     assert inputs.size(1) == lora_b_weights.size(-1)
     assert b_seq_start_loc.size(0) == batches
     assert lora_indices_tensor.size(0) == batches

View File

@@ -110,6 +110,7 @@ def _sgmv_shrink(
     lora_indices_tensor: torch.Tensor,
     batches: int,
     max_seq_length: int,
+    token_nums: int,
     scaling: float,
 ) -> None:
     """
@@ -122,14 +123,16 @@ def _sgmv_shrink(
             sequence lengths of the sequences in the batch, used to index
             into sequence. E.g., if the sequence length is [4, 6], it is
             [0, 4].
-        seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
-            length of the sequences in the batch
+        seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
+            length of the sequences in the batch.
         lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
             corresponding to each batch. An index of -1 means no lora should be
             applied.
         batches (int): batch size
-        max_seq_length (int): The max sequence lengths of the sequences
-            in the batch
+        max_seq_length (int): The max sequence lengths of the sequences in the
+            batch.
+        token_nums (int): The token numbers in the batch. Used to verify if the
+            token numbers in the inputs matches the one in the metadata.
         scaling (float): Scaling factor.
     """
     assert inputs.dtype == lora_a_weights.dtype
@@ -138,6 +141,7 @@ def _sgmv_shrink(
         torch.float16,
         torch.bfloat16,
     ]
+    assert inputs.size(0) == token_nums
    assert inputs.size(1) == lora_a_weights.size(-1)
    assert b_seq_start_loc.size(0) == batches
    assert lora_indices_tensor.size(0) == batches
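
Taken together, the assertions pin down the shrink kernel's shapes: inputs is (token_nums, hidden_size), lora_a_weights stacks one (rank, hidden_size) matrix per LoRA, and the per-segment metadata tensors all have batches entries. The following is a hedged reference sketch in plain PyTorch of the grouped GEMM these checks guard; it mirrors the argument names from the diff but is not the Triton kernel itself, and output_tensor is an assumed pre-allocated (token_nums, rank) buffer.

import torch

def ref_sgmv_shrink(inputs, lora_a_weights, output_tensor, b_seq_start_loc,
                    seq_len_tensor, lora_indices_tensor, batches,
                    max_seq_length, token_nums, scaling):
    # The same sanity checks the kernel now performs up front.
    assert inputs.size(0) == token_nums
    assert inputs.size(1) == lora_a_weights.size(-1)
    assert b_seq_start_loc.size(0) == batches
    assert lora_indices_tensor.size(0) == batches
    for i in range(batches):
        lora_idx = lora_indices_tensor[i].item()
        if lora_idx == -1:  # -1 means no LoRA for this segment
            continue
        start = b_seq_start_loc[i].item()
        length = seq_len_tensor[i].item()
        segment = inputs[start:start + length]   # (length, hidden_size)
        lora_a = lora_a_weights[lora_idx]        # (rank, hidden_size)
        output_tensor[start:start + length] = scaling * (segment @ lora_a.t())
    return output_tensor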

View File

@@ -27,7 +27,7 @@ if TYPE_CHECKING:
 def compute_meta(
         token_lora_tensor: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, bool]:
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
     """
     Get the information required for the sgmv kernel. With the features:
     1. If consecutive requests in the batch use the same LoRA, this function
@@ -43,7 +43,7 @@ def compute_meta(
     b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
     b_seq_start_tensor[1:].copy_(cum_result[:-1])
     max_length = seq_length_tensor.max().item()
+    token_nums = seq_length_tensor.sum().item()
     batch_size = lora_indices_tensor.size(0)
     no_lora = False
     # -1 means no lora should be applied. Use `no_lora` to determine whether
@@ -52,7 +52,7 @@ def compute_meta(
     if batch_size == 1 and lora_indices_tensor == -1:
         no_lora = True
     return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
-            batch_size, max_length, no_lora)
+            batch_size, max_length, token_nums, no_lora)

 # TODO see if this can be vectorized
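
With token_nums added, compute_meta now returns a 7-tuple, so callers such as _update_prefill_metada (further down in this file) have to unpack the extra element. A small self-contained sketch of the metadata computation follows; the stand-in token_lora_tensor and the use of torch.unique_consecutive are illustrative, chosen to reproduce the quantities named in the diff.

import torch

# Stand-in mapping of each token to a LoRA index (-1 means no LoRA).
token_lora_tensor = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1, -1, -1])

# Collapse consecutive tokens that share a LoRA index into one segment each.
lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(
    token_lora_tensor, return_counts=True)
cum_result = torch.cumsum(seq_length_tensor, dim=0)
b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
b_seq_start_tensor[1:].copy_(cum_result[:-1])

max_length = seq_length_tensor.max().item()
token_nums = seq_length_tensor.sum().item()  # new: total token count
batch_size = lora_indices_tensor.size(0)

# The value carried in the metadata must match the number of input tokens.
assert token_nums == token_lora_tensor.size(0)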
@@ -216,6 +216,7 @@ class PunicaWrapper:
                                               dtype=torch.long,
                                               device=device)
         self.max_length: int = 0
+        self.token_nums: int = 0
         self.batch_size: int = -1
         self.is_prefill = False
         self.no_lora = False
@@ -276,13 +277,13 @@ class PunicaWrapper:
                 long_lora_offsets_tensor)
         else:
             self._long_lora_indices.zero_()
         self.indices_len[:] = indices_len

     def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:
         (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
-         batch_size, max_length, no_lora) = compute_meta(token_lora_tensor)
+         batch_size, max_length, token_nums,
+         no_lora) = compute_meta(token_lora_tensor)
         self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_(
             b_seq_start_tensor)
@@ -291,25 +292,28 @@ class PunicaWrapper:
             lora_indices_tensor)
         self.batch_size = batch_size
         self.max_length = max_length
+        self.token_nums = token_nums
         self.no_lora = no_lora

     @property
     def prefill_metadata(
-            self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
+        self
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
         """
         This property provides a convenient way to access the necessary
         metadata for prefill-related kernel computations.
-        1. seq_start_locs: Tensor of sequence start positions
-        2. seq_lengths: Tensor of sequence lengths
+        1. seq_start_locs: Tensor of sequence start positions.
+        2. seq_lengths: Tensor of sequence lengths.
         3. lora_indices_per_batch: Tensor of lora indices, and an index of
             -1 means no lora should be applied.
-        4. batch_size: batch size after clustering identical lora indices
-        5. max_length: The maximum sequence length in the batch
+        4. batch_size: Batch size after clustering identical lora indices.
+        5. max_length: The maximum sequence length in the batch.
+        6. token_nums: The token numbers in the batch.
         """
         return (self._seq_start_locs[:self.batch_size],
                 self._seq_lengths[:self.batch_size],
                 self._lora_indices_per_batch[:self.batch_size],
-                self.batch_size, self.max_length)
+                self.batch_size, self.max_length, self.token_nums)

     @property
     def token_lora_indices(self) -> torch.Tensor:
@@ -324,7 +328,7 @@ class PunicaWrapper:
     def sampler_indices(self) -> torch.Tensor:
         """
         This property is used to access the lora indices specifically for
-        LogitsProcessorWithLoRA
+        LogitsProcessorWithLoRA.
         """
         sampler_indices_len = self.indices_len[1]
         return self._sampler_indices[:sampler_indices_len]
@@ -332,7 +336,7 @@ class PunicaWrapper:
     @property
     def sampler_indices_padded(self) -> torch.Tensor:
         """
-        This property provides access to padded sampler indices
+        This property provides access to padded sampler indices.
         """
         indices_padded_len = self.indices_len[2]
         return self._sampler_indices_padded[:indices_padded_len]
@@ -341,7 +345,7 @@ class PunicaWrapper:
     def embeddings_indices(self) -> torch.Tensor:
         """
         This property provides access to the indices used for lora embeddings,
-        specifically for VocabParallelEmbeddingWithLoRA
+        specifically for VocabParallelEmbeddingWithLoRA.
         """
         embeddings_indices_len = self.indices_len[3]
         return self._embeddings_indices[:, :embeddings_indices_len]
@@ -350,7 +354,7 @@ class PunicaWrapper:
     def long_lora_indices(self) -> torch.Tensor:
         """
         This property provides access to the indices used for long context
-        lora, specifically for LinearScalingRotaryEmbeddingWithLora
+        lora, specifically for LinearScalingRotaryEmbeddingWithLora.
        """
         long_lora_len = self.indices_len[4]
         return self._long_lora_indices[:long_lora_len]
@@ -524,7 +528,7 @@ class PunicaWrapper:
             scale (float): Scaling factor.
             y_offset (Optional[int], optional): Offset to apply to the starting
                 column of y.
-            y_slice_size (Optional[int], optional): Size of the y column slice..
+            y_slice_size (Optional[int], optional): Size of the y column slice.
             buffer (Optional[torch.Tensor], optional): Defaults to None.
         """
         y_org = y