[mypy] Forward pass function type hints in lora (#11740)

Signed-off-by: lucast2021 <lucast2021@headroyce.org>
Co-authored-by: lucast2021 <lucast2021@headroyce.org>
Lucas Tucker authored 2025-01-06 01:59:36 -06:00, committed by GitHub
parent 022c5c6944
commit 9c749713f6
3 changed files with 14 additions and 5 deletions


@@ -405,7 +405,9 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
         self.output_size = self.base_layer.output_size
         self.n_slices = 1
 
-    def forward(self, input_):
+    def forward(
+        self, input_: torch.Tensor
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         """Forward of ReplicatedLinearWithLoRA
 
         Args:
@@ -496,7 +498,9 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
         bias = bias[start_idx:end_idx]
         return bias
 
-    def forward(self, input_):
+    def forward(
+        self, input_: torch.Tensor
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         """Forward of ColumnParallelLinear
 
         Args:
@@ -833,7 +837,9 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
     def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
         return bias
 
-    def forward(self, input_):
+    def forward(
+        self, input_: torch.Tensor
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         """Forward of RowParallelLinear
 
         Args:

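All three wrappers follow the convention, visible in the ReplicatedLinear hunk further down, of returning a pair: the layer output plus an optional bias that is deferred to the caller when skip_bias_add is set. That is why the annotated return type is Tuple[Optional[torch.Tensor], Optional[torch.Tensor]] rather than a bare tensor. A minimal, self-contained sketch of the pattern these hints describe (LinearPairSketch and its attributes are illustrative stand-ins, not the vLLM classes):

from typing import Optional, Tuple

import torch


class LinearPairSketch(torch.nn.Module):
    """Illustrative stand-in for the LoRA linear wrappers typed above."""

    def __init__(self, weight: torch.Tensor,
                 bias: Optional[torch.Tensor] = None,
                 skip_bias_add: bool = False) -> None:
        super().__init__()
        self.weight = weight
        self.bias = bias
        self.skip_bias_add = skip_bias_add

    def forward(
        self, input_: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        # Return (output, output_bias): the bias is folded into the
        # matmul unless skip_bias_add defers it to the caller.
        bias = self.bias if not self.skip_bias_add else None
        output = torch.nn.functional.linear(input_, self.weight, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias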

@@ -4,7 +4,7 @@ import math
 import os
 import re
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, List, Optional, Sequence, Type
+from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
 
 import safetensors.torch
 import torch
@@ -219,6 +219,7 @@ class LoRAModel(AdapterModel):
         config["vllm_max_position_embeddings"] = max_position_embeddings
         peft_helper = PEFTHelper.from_dict(config)
 
+        unexpected_modules: List[Union[list[str], str]]
         if os.path.isfile(lora_tensor_path):
             tensors: Dict[str, torch.Tensor] = {}
             # Find unexpected modules.

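The added annotation pre-declares the type of unexpected_modules before two branches assign to it with different element types; without it, mypy infers the type from the first assignment alone and rejects the second. A small sketch of the idiom, with hypothetical data in place of the vLLM loading code:

from typing import List, Union


def find_unexpected(from_tensor_file: bool) -> List[Union[list[str], str]]:
    # Declared once up front so mypy accepts both branches below,
    # which assign different element types to the same name.
    unexpected_modules: List[Union[list[str], str]]
    if from_tensor_file:
        unexpected_modules = [["model", "layers", "0"]]  # split path parts
    else:
        unexpected_modules = ["model.layers.0"]  # dotted module names
    return unexpected_modules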

@@ -238,7 +238,9 @@ class ReplicatedLinear(LinearBase):
         assert param.size() == loaded_weight.size()
         param.data.copy_(loaded_weight)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, x: torch.Tensor
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         bias = self.bias if not self.skip_bias_add else None
         assert self.quant_method is not None
         output = self.quant_method.apply(self, x, bias)
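
For callers, the payoff is that mypy now checks how the returned pair is consumed instead of inferring Any from an untyped forward. A hypothetical caller, typed against a Protocol that mirrors the annotated signature (ForwardReturnsPair and apply_layer are illustrative names, not part of vLLM):

from typing import Optional, Protocol, Tuple

import torch


class ForwardReturnsPair(Protocol):
    # Structural type matching the forward signature annotated above.
    def forward(
        self, input_: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: ...


def apply_layer(layer: ForwardReturnsPair, x: torch.Tensor) -> torch.Tensor:
    output, output_bias = layer.forward(x)
    assert output is not None
    # Add the deferred bias when the layer skipped it internally;
    # mypy now flags any caller that treats forward as tensor-valued.
    return output if output_bias is None else output + output_bias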