[Misc] Minor improvements to the readability of PunicaWrapperBase (#11200)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2024-12-15 00:38:27 +08:00 committed by GitHub
parent ea7bd68d10
commit 3cb5769883
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 27 additions and 25 deletions

View File

@ -63,7 +63,7 @@ class PunicaWrapperABC(ABC):
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...], output_slices: Tuple[int, ...],
offset_start: int = 0, offset_start: int = 0,
add_input=True, add_inputs=True,
**kwargs, **kwargs,
) -> None: ) -> None:
""" """
@ -77,7 +77,7 @@ class PunicaWrapperABC(ABC):
y: torch.Tensor, y: torch.Tensor,
x: torch.Tensor, x: torch.Tensor,
lora_b_stacked: torch.Tensor, lora_b_stacked: torch.Tensor,
add_input: bool = True, add_inputs: bool = True,
**kwargs, **kwargs,
) -> None: ) -> None:
""" """
@ -367,12 +367,13 @@ class PunicaWrapperBase(PunicaWrapperABC):
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...], output_slices: Tuple[int, ...],
offset_start: int = 0, offset_start: int = 0,
add_input=True, add_inputs=True,
**kwargs) -> None: **kwargs) -> None:
""" """
Performs GEMM and bias addition for multiple slices of lora_b. Performs GEMM and bias addition for multiple slices of lora_b.
Semantics: Semantics:
offset = offset_start
for i in range(len(lora_b_stacked)): for i in range(len(lora_b_stacked)):
slice = output_slices[i] slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
@ -386,7 +387,8 @@ class PunicaWrapperBase(PunicaWrapperABC):
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
bias's weight bias's weight
output_slices (Tuple[int, ...]): Every slice's size output_slices (Tuple[int, ...]): Every slice's size
add_input (bool): Defaults to True. offset_start (int): The starting position of y, defaults to 0
add_inputs (bool): Defaults to True.
""" """
# TODO: implement it based on torch ops # TODO: implement it based on torch ops
@ -397,7 +399,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
y: torch.Tensor, y: torch.Tensor,
x: torch.Tensor, x: torch.Tensor,
lora_b_stacked: torch.Tensor, lora_b_stacked: torch.Tensor,
add_input: bool = True, add_inputs: bool = True,
**kwargs) -> None: **kwargs) -> None:
""" """
Applies lora specifically for VocabParallelEmbeddingWithLoRA. Applies lora specifically for VocabParallelEmbeddingWithLoRA.
@ -409,7 +411,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
y (torch.Tensor): Output tensor. y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor. x (torch.Tensor): Input tensor.
lora_b_stacked (torch.Tensor): lora_b's weights. lora_b_stacked (torch.Tensor): lora_b's weights.
add_input (bool): Default to True. add_inputs (bool): Default to True.
""" """
# TODO: implement it based on torch ops # TODO: implement it based on torch ops
raise NotImplementedError raise NotImplementedError

View File

@ -67,7 +67,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
y: torch.Tensor, y: torch.Tensor,
x: torch.Tensor, x: torch.Tensor,
w_t_all: torch.Tensor, w_t_all: torch.Tensor,
add_input: bool, add_inputs: bool,
): ):
#No LoRA request, so return directly #No LoRA request, so return directly
if self.no_lora: if self.no_lora:
@ -77,7 +77,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
w_t_all, w_t_all,
y, y,
*self.prefill_metadata, *self.prefill_metadata,
add_input, add_inputs,
) )
def _expand_decode( def _expand_decode(
@ -85,9 +85,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
y: torch.Tensor, y: torch.Tensor,
x: torch.Tensor, x: torch.Tensor,
w_t_all: torch.Tensor, w_t_all: torch.Tensor,
add_input: bool, add_inputs: bool,
): ):
bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input) bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs)
def _expand_slice_prefill( def _expand_slice_prefill(
self, self,
@ -96,7 +96,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
w_t_all: torch.Tensor, w_t_all: torch.Tensor,
y_offset: Optional[int], y_offset: Optional[int],
y_slice_size: Optional[int], y_slice_size: Optional[int],
add_input: bool, add_inputs: bool,
): ):
#No LoRA request, so return directly #No LoRA request, so return directly
if self.no_lora: if self.no_lora:
@ -108,7 +108,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
*self.prefill_metadata, *self.prefill_metadata,
y_offset, y_offset,
y_slice_size, y_slice_size,
add_input, add_inputs,
) )
def _expand_slice_decode( def _expand_slice_decode(
@ -118,10 +118,10 @@ class PunicaWrapperGPU(PunicaWrapperBase):
w_t_all: torch.Tensor, w_t_all: torch.Tensor,
y_offset: Optional[int], y_offset: Optional[int],
y_slice_size: Optional[int], y_slice_size: Optional[int],
add_input: bool, add_inputs: bool,
): ):
bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
y_slice_size, add_input) y_slice_size, add_inputs)
def _apply_expand( def _apply_expand(
self, self,
@ -130,7 +130,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
w_t_all: torch.Tensor, w_t_all: torch.Tensor,
y_offset: Optional[int], y_offset: Optional[int],
y_slice_size: Optional[int], y_slice_size: Optional[int],
add_input: bool = True, add_inputs: bool = True,
): ):
""" """
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
@ -141,7 +141,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
expand_slice_fun: Callable = (self._expand_slice_prefill expand_slice_fun: Callable = (self._expand_slice_prefill
if self.is_prefill else if self.is_prefill else
self._expand_slice_decode) self._expand_slice_decode)
expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input) expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)
def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor,
w_t_all: torch.Tensor, scale: float): w_t_all: torch.Tensor, scale: float):
@ -194,7 +194,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...], output_slices: Tuple[int, ...],
offset_start: int = 0, offset_start: int = 0,
add_input=True, add_inputs=True,
**kwargs) -> None: **kwargs) -> None:
""" """
Performs GEMM and bias addition for multiple slices of lora_b. Performs GEMM and bias addition for multiple slices of lora_b.
@ -213,7 +213,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
bias's weight bias's weight
output_slices (Tuple[int, ...]): Every slice's size output_slices (Tuple[int, ...]): Every slice's size
add_input (bool): Defaults to True. add_inputs (bool): Defaults to True.
""" """
y_org = y y_org = y
y = y.view(-1, y.shape[-1]) y = y.view(-1, y.shape[-1])
@ -228,7 +228,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
lora_b_stacked[slice_idx], lora_b_stacked[slice_idx],
offset_left, offset_left,
output_slices[slice_idx], output_slices[slice_idx],
add_input=add_input, add_inputs=add_inputs,
) )
offset_left += output_slices[slice_idx] offset_left += output_slices[slice_idx]
y = y.view_as(y_org) y = y.view_as(y_org)
@ -237,7 +237,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
y: torch.Tensor, y: torch.Tensor,
x: torch.Tensor, x: torch.Tensor,
lora_b_stacked: torch.Tensor, lora_b_stacked: torch.Tensor,
add_input: bool = True, add_inputs: bool = True,
**kwargs) -> None: **kwargs) -> None:
""" """
Applies lora specifically for VocabParallelEmbeddingWithLoRA. Applies lora specifically for VocabParallelEmbeddingWithLoRA.
@ -249,13 +249,13 @@ class PunicaWrapperGPU(PunicaWrapperBase):
y (torch.Tensor): Output tensor. y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor. x (torch.Tensor): Input tensor.
lora_b_stacked (torch.Tensor): lora_b's weights. lora_b_stacked (torch.Tensor): lora_b's weights.
add_input (bool): Default to True. add_inputs (bool): Default to True.
""" """
# Embedding layer only need expand op # Embedding layer only need expand op
expand_fun: Callable = (self._expand_prefill expand_fun: Callable = (self._expand_prefill
if self.is_prefill else self._expand_decode) if self.is_prefill else self._expand_decode)
expand_fun(y, x, lora_b_stacked, add_input) expand_fun(y, x, lora_b_stacked, add_inputs)
def add_lora_linear(self, def add_lora_linear(self,
y: torch.Tensor, y: torch.Tensor,
@ -311,7 +311,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
lora_b_stacked, lora_b_stacked,
None, None,
output_slices, output_slices,
add_input=True, add_inputs=True,
**kwargs) **kwargs)
def add_lora_logits(self, def add_lora_logits(self,

View File

@ -21,7 +21,7 @@ class PunicaWrapperHPU(PunicaWrapperBase):
y: torch.Tensor, y: torch.Tensor,
x: torch.Tensor, x: torch.Tensor,
lora_b_stacked: torch.Tensor, lora_b_stacked: torch.Tensor,
add_input: bool = True, add_inputs: bool = True,
**kwargs) -> None: **kwargs) -> None:
dispatch_bgmv_embedding(y, x, lora_b_stacked, 0) dispatch_bgmv_embedding(y, x, lora_b_stacked, 0)
@ -81,7 +81,7 @@ class PunicaWrapperHPU(PunicaWrapperBase):
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...], output_slices: Tuple[int, ...],
offset_start: int = 0, offset_start: int = 0,
add_input=True, add_inputs=True,
**kwargs, **kwargs,
) -> None: ) -> None:
raise NotImplementedError raise NotImplementedError