[Misc] Minor improvements to the readability of PunicaWrapperBase (#11200)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
commit 3cb5769883
parent ea7bd68d10
@@ -63,7 +63,7 @@ class PunicaWrapperABC(ABC):
         lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
         output_slices: Tuple[int, ...],
         offset_start: int = 0,
-        add_input=True,
+        add_inputs=True,
         **kwargs,
     ) -> None:
         """
@@ -77,7 +77,7 @@ class PunicaWrapperABC(ABC):
         y: torch.Tensor,
         x: torch.Tensor,
         lora_b_stacked: torch.Tensor,
-        add_input: bool = True,
+        add_inputs: bool = True,
         **kwargs,
     ) -> None:
         """
@@ -367,12 +367,13 @@ class PunicaWrapperBase(PunicaWrapperABC):
         lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
         output_slices: Tuple[int, ...],
         offset_start: int = 0,
-        add_input=True,
+        add_inputs=True,
         **kwargs) -> None:
         """
         Performs GEMM and bias addition for multiple slices of lora_b.
 
         Semantics:
+            offset = offset_start
             for i in range(len(lora_b_stacked)):
                 slice = output_slices[i]
                 y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
@@ -386,7 +387,8 @@ class PunicaWrapperBase(PunicaWrapperABC):
             lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
                 bias's weight
             output_slices (Tuple[int, ...]): Every slice's size
-            add_input (bool): Defaults to True.
+            offset_start (int): The starting position of y, defaults to 0
+            add_inputs (bool): Defaults to True.
 
         """
         # TODO: implement it based on torch ops
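Note (illustration, not part of the diff): the Semantics block above reads directly as plain tensor ops. The sketch below mirrors that loop with ordinary torch operations; the helper name, the tuple-of-2D-tensors layout, and the shapes are illustrative assumptions rather than the wrapper's real buffers or the Punica kernels, and the accumulate-vs-overwrite branch is one reading of what add_inputs controls.

    import torch

    def add_expand_sketch(y, xs, lora_b_stacked, lora_bias_stacked,
                          output_slices, offset_start=0, add_inputs=True):
        # Mirrors the documented semantics with plain torch ops; the real
        # wrapper dispatches to sgmv/bgmv expand kernels instead.
        offset = offset_start
        for i in range(len(lora_b_stacked)):
            slice_size = output_slices[i]
            delta = xs[i] @ lora_b_stacked[i]
            if lora_bias_stacked is not None:
                delta = delta + lora_bias_stacked[i]
            if add_inputs:
                y[:, offset:offset + slice_size] += delta
            else:
                y[:, offset:offset + slice_size] = delta
            offset += slice_size

    # Illustrative shapes: 4 tokens, LoRA rank 8, two output slices of width 16.
    xs = (torch.randn(4, 8), torch.randn(4, 8))
    lora_b = (torch.randn(8, 16), torch.randn(8, 16))
    bias = (torch.zeros(16), torch.zeros(16))
    y = torch.zeros(4, 32)
    add_expand_sketch(y, xs, lora_b, bias, output_slices=(16, 16))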
@@ -397,7 +399,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
         y: torch.Tensor,
         x: torch.Tensor,
         lora_b_stacked: torch.Tensor,
-        add_input: bool = True,
+        add_inputs: bool = True,
         **kwargs) -> None:
         """
         Applies lora specifically for VocabParallelEmbeddingWithLoRA.
@@ -409,7 +411,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
             y (torch.Tensor): Output tensor.
             x (torch.Tensor): Input tensor.
             lora_b_stacked (torch.Tensor): lora_b's weights.
-            add_input (bool): Default to True.
+            add_inputs (bool): Default to True.
         """
         # TODO: implement it based on torch ops
         raise NotImplementedError
@@ -67,7 +67,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         y: torch.Tensor,
         x: torch.Tensor,
         w_t_all: torch.Tensor,
-        add_input: bool,
+        add_inputs: bool,
     ):
         #No LoRA request, so return directly
         if self.no_lora:
@@ -77,7 +77,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
             w_t_all,
             y,
             *self.prefill_metadata,
-            add_input,
+            add_inputs,
         )
 
     def _expand_decode(
@@ -85,9 +85,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         y: torch.Tensor,
         x: torch.Tensor,
         w_t_all: torch.Tensor,
-        add_input: bool,
+        add_inputs: bool,
     ):
-        bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input)
+        bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs)
 
     def _expand_slice_prefill(
         self,
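Note (illustration, not part of the diff): bgmv_expand applies, per token, the LoRA-B matrix selected by self.token_lora_indices, and add_inputs decides whether the result is accumulated into y. A rough torch-only equivalent follows; the weight layout, the -1 "no LoRA" sentinel, and the shapes are assumptions for illustration, not the kernel's actual interface.

    import torch

    def bgmv_expand_sketch(x, w_t_all, y, token_lora_indices, add_inputs):
        # x:        (num_tokens, rank)         shrunk activations
        # w_t_all:  (num_loras, hidden, rank)  stacked LoRA-B weights (assumed layout)
        # y:        (num_tokens, hidden)       output buffer, updated in place
        for i, lora_idx in enumerate(token_lora_indices.tolist()):
            if lora_idx < 0:  # assumed sentinel for "this token uses no LoRA"
                continue
            out = w_t_all[lora_idx] @ x[i]
            if add_inputs:
                y[i] += out   # accumulate on top of the existing output
            else:
                y[i] = out    # overwrite

    x = torch.randn(3, 8)
    w_t_all = torch.randn(2, 32, 8)
    y = torch.zeros(3, 32)
    bgmv_expand_sketch(x, w_t_all, y, torch.tensor([0, 1, -1]), add_inputs=True)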
@@ -96,7 +96,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         w_t_all: torch.Tensor,
         y_offset: Optional[int],
         y_slice_size: Optional[int],
-        add_input: bool,
+        add_inputs: bool,
     ):
         #No LoRA request, so return directly
         if self.no_lora:
@@ -108,7 +108,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
             *self.prefill_metadata,
             y_offset,
             y_slice_size,
-            add_input,
+            add_inputs,
         )
 
     def _expand_slice_decode(
@@ -118,10 +118,10 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         w_t_all: torch.Tensor,
         y_offset: Optional[int],
         y_slice_size: Optional[int],
-        add_input: bool,
+        add_inputs: bool,
     ):
         bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
-                          y_slice_size, add_input)
+                          y_slice_size, add_inputs)
 
     def _apply_expand(
         self,
@@ -130,7 +130,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         w_t_all: torch.Tensor,
         y_offset: Optional[int],
         y_slice_size: Optional[int],
-        add_input: bool = True,
+        add_inputs: bool = True,
     ):
         """
         Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
@@ -141,7 +141,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         expand_slice_fun: Callable = (self._expand_slice_prefill
                                       if self.is_prefill else
                                       self._expand_slice_decode)
-        expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input)
+        expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)
 
     def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor,
                       w_t_all: torch.Tensor, scale: float):
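Note (illustration, not part of the diff): _apply_shrink and _apply_expand are the two halves of the usual LoRA update, y += (x @ lora_a * scale) @ lora_b, with the expand half optionally targeting a sub-slice of y. A toy single-adapter sketch of that pairing, with names and shapes chosen only for illustration:

    import torch

    def lora_shrink_expand_sketch(y, x, lora_a, lora_b, scale, add_inputs=True):
        buffer = (x @ lora_a) * scale  # shrink: down-project to the LoRA rank
        out = buffer @ lora_b          # expand: project back to the output width
        if add_inputs:
            y += out                   # accumulate into the existing output
        else:
            y.copy_(out)               # overwrite the output buffer

    x = torch.randn(4, 64)             # input activations
    y = torch.zeros(4, 32)             # output buffer for this slice
    lora_shrink_expand_sketch(y, x, torch.randn(64, 8), torch.randn(8, 32), scale=0.5)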
@@ -194,7 +194,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
         output_slices: Tuple[int, ...],
         offset_start: int = 0,
-        add_input=True,
+        add_inputs=True,
         **kwargs) -> None:
         """
         Performs GEMM and bias addition for multiple slices of lora_b.
@@ -213,7 +213,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
             lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
                 bias's weight
             output_slices (Tuple[int, ...]): Every slice's size
-            add_input (bool): Defaults to True.
+            add_inputs (bool): Defaults to True.
         """
         y_org = y
         y = y.view(-1, y.shape[-1])
@@ -228,7 +228,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
                 lora_b_stacked[slice_idx],
                 offset_left,
                 output_slices[slice_idx],
-                add_input=add_input,
+                add_inputs=add_inputs,
             )
             offset_left += output_slices[slice_idx]
         y = y.view_as(y_org)
@@ -237,7 +237,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         y: torch.Tensor,
         x: torch.Tensor,
         lora_b_stacked: torch.Tensor,
-        add_input: bool = True,
+        add_inputs: bool = True,
         **kwargs) -> None:
         """
         Applies lora specifically for VocabParallelEmbeddingWithLoRA.
@@ -249,13 +249,13 @@ class PunicaWrapperGPU(PunicaWrapperBase):
             y (torch.Tensor): Output tensor.
             x (torch.Tensor): Input tensor.
             lora_b_stacked (torch.Tensor): lora_b's weights.
-            add_input (bool): Default to True.
+            add_inputs (bool): Default to True.
         """
 
         # Embedding layer only need expand op
         expand_fun: Callable = (self._expand_prefill
                                 if self.is_prefill else self._expand_decode)
-        expand_fun(y, x, lora_b_stacked, add_input)
+        expand_fun(y, x, lora_b_stacked, add_inputs)
 
     def add_lora_linear(self,
                         y: torch.Tensor,
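Note (illustration, not part of the diff): the "Embedding layer only need expand op" comment is because, for an embedding layer, the LoRA-A projection can be realized as just another table lookup performed upstream, so only the B-side expand is left for the wrapper. A single-adapter toy version of that idea (the multi-LoRA batching and the actual VocabParallelEmbeddingWithLoRA wiring are omitted):

    import torch

    vocab, hidden, rank = 100, 32, 8
    token_ids = torch.tensor([3, 17, 42])

    base_embedding = torch.nn.Embedding(vocab, hidden)
    lora_a = torch.randn(vocab, rank)   # A side: acts as a second lookup table
    lora_b = torch.randn(rank, hidden)  # B side: the "expand" weight

    y = base_embedding(token_ids)       # base embedding output
    x = lora_a[token_ids]               # A side applied as an embedding lookup
    y = y + x @ lora_b                  # only the expand step remains here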
@@ -311,7 +311,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
             lora_b_stacked,
             None,
             output_slices,
-            add_input=True,
+            add_inputs=True,
             **kwargs)
 
     def add_lora_logits(self,
@@ -21,7 +21,7 @@ class PunicaWrapperHPU(PunicaWrapperBase):
         y: torch.Tensor,
         x: torch.Tensor,
         lora_b_stacked: torch.Tensor,
-        add_input: bool = True,
+        add_inputs: bool = True,
         **kwargs) -> None:
         dispatch_bgmv_embedding(y, x, lora_b_stacked, 0)
 
@@ -81,7 +81,7 @@ class PunicaWrapperHPU(PunicaWrapperBase):
         lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
         output_slices: Tuple[int, ...],
         offset_start: int = 0,
-        add_input=True,
+        add_inputs=True,
         **kwargs,
     ) -> None:
         raise NotImplementedError