[mypy] Forward pass function type hints in lora (#11740)

Signed-off-by: lucast2021 <lucast2021@headroyce.org>
Co-authored-by: lucast2021 <lucast2021@headroyce.org>

parent 022c5c6944
commit 9c749713f6
@@ -405,7 +405,9 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
         self.output_size = self.base_layer.output_size
         self.n_slices = 1
 
-    def forward(self, input_):
+    def forward(
+        self, input_: torch.Tensor
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         """Forward of ReplicatedLinearWithLoRA
 
         Args:
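The annotation added here (and in the two hunks that follow) spells out the two-tuple these layers already return: the layer output plus an optional bias that is only passed back when bias addition is deferred to the caller. A minimal, self-contained sketch of that shape, with illustrative names rather than vLLM's implementation:

from typing import Optional, Tuple

import torch


class TinyReplicatedLinear:
    """Illustrative stand-in for the annotated layers, not vLLM code."""

    def __init__(self, size: int, skip_bias_add: bool = False) -> None:
        self.skip_bias_add = skip_bias_add
        self.weight = torch.randn(size, size)
        self.bias = torch.zeros(size)

    def forward(
        self, input_: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        # Add the bias inside the layer unless the caller asked to apply it later.
        bias = self.bias if not self.skip_bias_add else None
        output = torch.nn.functional.linear(input_, self.weight, bias)
        # output_bias is only populated when bias addition was deferred;
        # both elements are typed Optional to mirror the annotation in the diff.
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias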
@@ -496,7 +498,9 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
             bias = bias[start_idx:end_idx]
         return bias
 
-    def forward(self, input_):
+    def forward(
+        self, input_: torch.Tensor
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         """Forward of ColumnParallelLinear
 
         Args:
@@ -833,7 +837,9 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
     def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
         return bias
 
-    def forward(self, input_):
+    def forward(
+        self, input_: torch.Tensor
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         """Forward of RowParallelLinear
 
         Args:
@@ -4,7 +4,7 @@ import math
 import os
 import re
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, List, Optional, Sequence, Type
+from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
 
 import safetensors.torch
 import torch
@@ -219,6 +219,7 @@ class LoRAModel(AdapterModel):
 
         config["vllm_max_position_embeddings"] = max_position_embeddings
         peft_helper = PEFTHelper.from_dict(config)
+        unexpected_modules: List[Union[list[str], str]]
         if os.path.isfile(lora_tensor_path):
             tensors: Dict[str, torch.Tensor] = {}
             # Find unexpected modules.
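The `unexpected_modules: List[Union[list[str], str]]` line is a bare variable annotation: declaring the type once, before the branches that assign to the variable, lets mypy check both differently shaped assignments against the same declared type. A small, self-contained illustration of the pattern with hypothetical values, not the vLLM loading logic:

from typing import List, Union


def collect_unexpected(from_safetensors: bool) -> List[Union[List[str], str]]:
    # Annotate the variable before either branch assigns to it, so mypy
    # checks both assignments against the same declared type instead of
    # inferring a narrower type from the first assignment it sees.
    unexpected_modules: List[Union[List[str], str]]
    if from_safetensors:
        unexpected_modules = ["model.layers.0.mlp.extra_proj"]
    else:
        unexpected_modules = [["model", "layers", "0", "mlp", "extra_proj"]]
    return unexpected_modules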
@@ -238,7 +238,9 @@ class ReplicatedLinear(LinearBase):
         assert param.size() == loaded_weight.size()
         param.data.copy_(loaded_weight)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, x: torch.Tensor
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         bias = self.bias if not self.skip_bias_add else None
         assert self.quant_method is not None
         output = self.quant_method.apply(self, x, bias)
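With the same signature change applied to ReplicatedLinear itself, mypy sees the two-element return type at call sites as well, so consuming code has to narrow both Optionals before using them. A hedged usage sketch with a hypothetical helper, not one of vLLM's call sites:

from typing import Optional, Tuple

import torch


def resolve_output(
    result: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]
) -> torch.Tensor:
    # Unpack (output, output_bias) and narrow both Optionals explicitly,
    # which is what the Tuple[Optional[...], Optional[...]] return type
    # asks callers to do.
    output, output_bias = result
    assert output is not None
    return output if output_bias is None else output + output_bias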