mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:25:00 +08:00
[Kernel][LoRA] Add assertion for punica sgmv kernels (#7585)
This commit is contained in:
parent
86e9c8df29
commit
9b0e3ec970
@ -169,6 +169,7 @@ def test_punica_sgmv(
|
||||
device,
|
||||
)
|
||||
max_seq_length = seq_len_tensor.max()
|
||||
token_nums = seq_len_tensor.sum().item()
|
||||
if isinstance(max_seq_length, tuple):
|
||||
max_seq_length = max_seq_length[0].item()
|
||||
else:
|
||||
@ -183,6 +184,7 @@ def test_punica_sgmv(
|
||||
lora_indices_tensor,
|
||||
batches,
|
||||
max_seq_length,
|
||||
token_nums,
|
||||
scaling,
|
||||
)
|
||||
else:
|
||||
@ -195,6 +197,7 @@ def test_punica_sgmv(
|
||||
lora_indices_tensor,
|
||||
batches,
|
||||
max_seq_length,
|
||||
token_nums,
|
||||
add_inputs=True,
|
||||
)
|
||||
ref_torch_groupgemm(
|
||||
@ -347,6 +350,7 @@ def test_punica_expand_nslices(
|
||||
device,
|
||||
)
|
||||
max_seq_length = seq_len_tensor.max()
|
||||
token_nums = seq_len_tensor.sum().item()
|
||||
if isinstance(max_seq_length, tuple):
|
||||
max_seq_length = max_seq_length[0].item()
|
||||
else:
|
||||
@ -364,6 +368,7 @@ def test_punica_expand_nslices(
|
||||
lora_indices_tensor,
|
||||
batches,
|
||||
max_seq_length,
|
||||
token_nums,
|
||||
slice_offset,
|
||||
hidden_size,
|
||||
add_inputs=True,
|
||||
|
||||
@ -84,6 +84,7 @@ def test_punica_sgmv(
|
||||
device,
|
||||
)
|
||||
max_seq_length = seq_len_tensor.max()
|
||||
token_nums = seq_len_tensor.sum().item()
|
||||
if isinstance(max_seq_length, tuple):
|
||||
max_seq_length = max_seq_length[0].item()
|
||||
else:
|
||||
@ -98,6 +99,7 @@ def test_punica_sgmv(
|
||||
lora_indices_tensor,
|
||||
batches,
|
||||
max_seq_length,
|
||||
token_nums,
|
||||
scaling,
|
||||
)
|
||||
else:
|
||||
@ -110,6 +112,7 @@ def test_punica_sgmv(
|
||||
lora_indices_tensor,
|
||||
batches,
|
||||
max_seq_length,
|
||||
token_nums,
|
||||
add_inputs=True,
|
||||
)
|
||||
ref_torch_groupgemm(
|
||||
@ -262,6 +265,7 @@ def test_punica_expand_nslices(
|
||||
device,
|
||||
)
|
||||
max_seq_length = seq_len_tensor.max()
|
||||
token_nums = seq_len_tensor.sum().item()
|
||||
if isinstance(max_seq_length, tuple):
|
||||
max_seq_length = max_seq_length[0].item()
|
||||
else:
|
||||
@ -279,6 +283,7 @@ def test_punica_expand_nslices(
|
||||
lora_indices_tensor,
|
||||
batches,
|
||||
max_seq_length,
|
||||
token_nums,
|
||||
slice_offset,
|
||||
hidden_size,
|
||||
add_inputs=True,
|
||||
|
||||
@ -100,7 +100,7 @@ def _bgmv_expand(
|
||||
corresponding to each batch, An index of -1 means no lora should be
|
||||
applied.
|
||||
batches (int): batch size
|
||||
add_inputs (bool, optional): Defaults to False. adds the final lora
|
||||
add_inputs (bool, optional): Defaults to False, adds the final lora
|
||||
results to the output.
|
||||
"""
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
@ -104,7 +104,7 @@ def _bgmv_expand_slice(
|
||||
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
|
||||
corresponding to each batch, An index of -1 means no lora should be
|
||||
applied.
|
||||
slice_offst (int): output_tensor's offst
|
||||
slice_offset (int): output_tensor's offset
|
||||
slice_size (int): current output_tensor's size
|
||||
batches (int): batch size
|
||||
add_inputs (bool, optional): Defaults to False.
|
||||
|
||||
@ -106,6 +106,7 @@ def _sgmv_expand(
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
add_inputs: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
@ -115,17 +116,19 @@ def _sgmv_expand(
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
|
||||
sequence lengths of the sequences in the batch, used to index
|
||||
into sequence. E.g.,if the sequence length is [4, 6], it is
|
||||
into sequence. E.g., if the sequence length is [4, 6], it is
|
||||
[0, 4, 10].
|
||||
seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
|
||||
length of the sequences in the batch
|
||||
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
|
||||
length of the sequences in the batch.
|
||||
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
|
||||
corresponding to each batch. An index of -1 means no lora should be
|
||||
applied.
|
||||
batches (int): batch size
|
||||
max_seq_length (int): The max sequence lengths of the sequences
|
||||
in the batch
|
||||
add_inputs (bool, optional): Defaults to False. adds the final lora
|
||||
max_seq_length (int): The max sequence lengths of the sequences in the
|
||||
batch.
|
||||
token_nums (int): The token numbers in the batch. Used to verify if the
|
||||
token numbers in the inputs matches the one in the metadata.
|
||||
add_inputs (bool, optional): Defaults to False, adds the final lora
|
||||
results to the output.
|
||||
"""
|
||||
|
||||
@ -134,6 +137,7 @@ def _sgmv_expand(
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]
|
||||
assert inputs.size(0) == token_nums
|
||||
assert inputs.size(1) == lora_b_weights.size(-1)
|
||||
assert b_seq_start_loc.size(0) == batches
|
||||
assert lora_indices_tensor.size(0) == batches
|
||||
|
||||
@ -112,6 +112,7 @@ def _sgmv_expand_slice(
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = False,
|
||||
@ -124,20 +125,22 @@ def _sgmv_expand_slice(
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
|
||||
sequence lengths of the sequences in the batch, used to index
|
||||
into sequence. E.g.,if the sequence length is [4, 6], it is
|
||||
into sequence. E.g., if the sequence length is [4, 6], it is
|
||||
[0, 4, 10].
|
||||
seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
|
||||
length of the sequences in the batch
|
||||
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
|
||||
length of the sequences in the batch
|
||||
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
|
||||
corresponding to each batch. An index of -1 means no lora should be
|
||||
applied.
|
||||
batches (int): batch size
|
||||
max_seq_length (int): The max sequence lengths of the sequences
|
||||
max_seq_length (int): The max sequence lengths of the sequences
|
||||
in the batch
|
||||
slice_offst (int): output_tensor's offst
|
||||
token_nums (int): The token numbers in the batch. Used to verify if the
|
||||
token numbers in the inputs matches the one in the metadata.
|
||||
slice_offset (int): output_tensor's offset
|
||||
slice_size (int): current output_tensor's size
|
||||
add_inputs (bool, optional): Defaults to False. adds the final lora
|
||||
results to the output..
|
||||
add_inputs (bool, optional): Defaults to False, adds the final lora
|
||||
results to the output.
|
||||
"""
|
||||
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
|
||||
@ -145,6 +148,7 @@ def _sgmv_expand_slice(
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]
|
||||
assert inputs.size(0) == token_nums
|
||||
assert inputs.size(1) == lora_b_weights.size(-1)
|
||||
assert b_seq_start_loc.size(0) == batches
|
||||
assert lora_indices_tensor.size(0) == batches
|
||||
|
||||
@ -110,6 +110,7 @@ def _sgmv_shrink(
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
scaling: float,
|
||||
) -> None:
|
||||
"""
|
||||
@ -120,17 +121,19 @@ def _sgmv_shrink(
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
|
||||
sequence lengths of the sequences in the batch, used to index
|
||||
into sequence. E.g.,if the sequence length is [4, 6], it is
|
||||
into sequence. E.g., if the sequence length is [4, 6], it is
|
||||
[0, 4].
|
||||
seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
|
||||
length of the sequences in the batch
|
||||
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
|
||||
length of the sequences in the batch.
|
||||
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
|
||||
corresponding to each batch. An index of -1 means no lora should be
|
||||
applied.
|
||||
batches (int): batch size
|
||||
max_seq_length (int): The max sequence lengths of the sequences
|
||||
in the batch
|
||||
scaling (float): Scaling factor.
|
||||
max_seq_length (int): The max sequence lengths of the sequences in the
|
||||
batch.
|
||||
token_nums (int): The token numbers in the batch. Used to verify if the
|
||||
token numbers in the inputs matches the one in the metadata.
|
||||
scaling (float): Scaling factor.
|
||||
"""
|
||||
assert inputs.dtype == lora_a_weights.dtype
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16]
|
||||
@ -138,6 +141,7 @@ def _sgmv_shrink(
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]
|
||||
assert inputs.size(0) == token_nums
|
||||
assert inputs.size(1) == lora_a_weights.size(-1)
|
||||
assert b_seq_start_loc.size(0) == batches
|
||||
assert lora_indices_tensor.size(0) == batches
|
||||
|
||||
@ -27,7 +27,7 @@ if TYPE_CHECKING:
|
||||
|
||||
def compute_meta(
|
||||
token_lora_tensor: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, bool]:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
|
||||
"""
|
||||
Get the information required for the sgmv kernel. With the features:
|
||||
1. If consecutive requests in the batch use the same LoRA, this function
|
||||
@ -43,7 +43,7 @@ def compute_meta(
|
||||
b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
|
||||
b_seq_start_tensor[1:].copy_(cum_result[:-1])
|
||||
max_length = seq_length_tensor.max().item()
|
||||
|
||||
token_nums = seq_length_tensor.sum().item()
|
||||
batch_size = lora_indices_tensor.size(0)
|
||||
no_lora = False
|
||||
# -1 means no lora should be applied. Use `no_lora` to determine whether
|
||||
@ -52,7 +52,7 @@ def compute_meta(
|
||||
if batch_size == 1 and lora_indices_tensor == -1:
|
||||
no_lora = True
|
||||
return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
|
||||
batch_size, max_length, no_lora)
|
||||
batch_size, max_length, token_nums, no_lora)
|
||||
|
||||
|
||||
# TODO see if this can be vectorized
|
||||
@ -178,7 +178,7 @@ def convert_mapping(
|
||||
class PunicaWrapper:
|
||||
"""
|
||||
PunicaWrapper is designed to manage and provide metadata for the punica
|
||||
kernel. The main function is to maintain the state information for
|
||||
kernel. The main function is to maintain the state information for
|
||||
Multi-LoRA, and to provide the interface for the punica kernel.
|
||||
"""
|
||||
|
||||
@ -216,6 +216,7 @@ class PunicaWrapper:
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
self.max_length: int = 0
|
||||
self.token_nums: int = 0
|
||||
self.batch_size: int = -1
|
||||
self.is_prefill = False
|
||||
self.no_lora = False
|
||||
@ -276,13 +277,13 @@ class PunicaWrapper:
|
||||
long_lora_offsets_tensor)
|
||||
else:
|
||||
self._long_lora_indices.zero_()
|
||||
|
||||
self.indices_len[:] = indices_len
|
||||
|
||||
def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:
|
||||
|
||||
(b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
|
||||
batch_size, max_length, no_lora) = compute_meta(token_lora_tensor)
|
||||
batch_size, max_length, token_nums,
|
||||
no_lora) = compute_meta(token_lora_tensor)
|
||||
|
||||
self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_(
|
||||
b_seq_start_tensor)
|
||||
@ -291,25 +292,28 @@ class PunicaWrapper:
|
||||
lora_indices_tensor)
|
||||
self.batch_size = batch_size
|
||||
self.max_length = max_length
|
||||
self.token_nums = token_nums
|
||||
self.no_lora = no_lora
|
||||
|
||||
@property
|
||||
def prefill_metadata(
|
||||
self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
|
||||
self
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
|
||||
"""
|
||||
This property provides a convenient way to access the necessary
|
||||
metadata for prefill-related kernel computations.
|
||||
1. seq_start_locs: Tensor of sequence start positions
|
||||
2. seq_lengths: Tensor of sequence lengths
|
||||
1. seq_start_locs: Tensor of sequence start positions.
|
||||
2. seq_lengths: Tensor of sequence lengths.
|
||||
3. lora_indices_per_batch: Tensor of lora indices, and an index of
|
||||
-1 means no lora should be applied.
|
||||
4. batch_size: batch size after clustering identical lora indices
|
||||
5. max_length: The maximum sequence length in the batch
|
||||
4. batch_size: Batch size after clustering identical lora indices.
|
||||
5. max_length: The maximum sequence length in the batch.
|
||||
6. token_nums: The token numbers in the batch.
|
||||
"""
|
||||
return (self._seq_start_locs[:self.batch_size],
|
||||
self._seq_lengths[:self.batch_size],
|
||||
self._lora_indices_per_batch[:self.batch_size],
|
||||
self.batch_size, self.max_length)
|
||||
self.batch_size, self.max_length, self.token_nums)
|
||||
|
||||
@property
|
||||
def token_lora_indices(self) -> torch.Tensor:
|
||||
@ -324,7 +328,7 @@ class PunicaWrapper:
|
||||
def sampler_indices(self) -> torch.Tensor:
|
||||
"""
|
||||
This property is used to access the lora indices specifically for
|
||||
LogitsProcessorWithLoRA
|
||||
LogitsProcessorWithLoRA.
|
||||
"""
|
||||
sampler_indices_len = self.indices_len[1]
|
||||
return self._sampler_indices[:sampler_indices_len]
|
||||
@ -332,7 +336,7 @@ class PunicaWrapper:
|
||||
@property
|
||||
def sampler_indices_padded(self) -> torch.Tensor:
|
||||
"""
|
||||
This property provides access to padded sampler indices
|
||||
This property provides access to padded sampler indices.
|
||||
"""
|
||||
indices_padded_len = self.indices_len[2]
|
||||
return self._sampler_indices_padded[:indices_padded_len]
|
||||
@ -341,7 +345,7 @@ class PunicaWrapper:
|
||||
def embeddings_indices(self) -> torch.Tensor:
|
||||
"""
|
||||
This property provides access to the indices used for lora embeddings,
|
||||
specifically for VocabParallelEmbeddingWithLoRA
|
||||
specifically for VocabParallelEmbeddingWithLoRA.
|
||||
"""
|
||||
embeddings_indices_len = self.indices_len[3]
|
||||
return self._embeddings_indices[:, :embeddings_indices_len]
|
||||
@ -350,7 +354,7 @@ class PunicaWrapper:
|
||||
def long_lora_indices(self) -> torch.Tensor:
|
||||
"""
|
||||
This property provides access to the indices used for long context
|
||||
lora, specifically for LinearScalingRotaryEmbeddingWithLora
|
||||
lora, specifically for LinearScalingRotaryEmbeddingWithLora.
|
||||
"""
|
||||
long_lora_len = self.indices_len[4]
|
||||
return self._long_lora_indices[:long_lora_len]
|
||||
@ -524,7 +528,7 @@ class PunicaWrapper:
|
||||
scale (float): Scaling factor.
|
||||
y_offset (Optional[int], optional): Offset to apply to the starting
|
||||
column of y.
|
||||
y_slice_size (Optional[int], optional): Size of the y column slice..
|
||||
y_slice_size (Optional[int], optional): Size of the y column slice.
|
||||
buffer (Optional[torch.Tensor], optional): Defaults to None.
|
||||
"""
|
||||
y_org = y
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user