mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-05 15:02:22 +08:00
[mypy] Forward pass function type hints in lora (#11740)
Signed-off-by: lucast2021 <lucast2021@headroyce.org> Co-authored-by: lucast2021 <lucast2021@headroyce.org>
This commit is contained in:
parent
022c5c6944
commit
9c749713f6
@ -405,7 +405,9 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
|
|||||||
self.output_size = self.base_layer.output_size
|
self.output_size = self.base_layer.output_size
|
||||||
self.n_slices = 1
|
self.n_slices = 1
|
||||||
|
|
||||||
def forward(self, input_):
|
def forward(
|
||||||
|
self, input_: torch.Tensor
|
||||||
|
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
"""Forward of ReplicatedLinearWithLoRA
|
"""Forward of ReplicatedLinearWithLoRA
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -496,7 +498,9 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
|||||||
bias = bias[start_idx:end_idx]
|
bias = bias[start_idx:end_idx]
|
||||||
return bias
|
return bias
|
||||||
|
|
||||||
def forward(self, input_):
|
def forward(
|
||||||
|
self, input_: torch.Tensor
|
||||||
|
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
"""Forward of ColumnParallelLinear
|
"""Forward of ColumnParallelLinear
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -833,7 +837,9 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
|||||||
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
|
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
|
||||||
return bias
|
return bias
|
||||||
|
|
||||||
def forward(self, input_):
|
def forward(
|
||||||
|
self, input_: torch.Tensor
|
||||||
|
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
"""Forward of RowParallelLinear
|
"""Forward of RowParallelLinear
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import math
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Type
|
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
|
||||||
|
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
import torch
|
import torch
|
||||||
@ -219,6 +219,7 @@ class LoRAModel(AdapterModel):
|
|||||||
|
|
||||||
config["vllm_max_position_embeddings"] = max_position_embeddings
|
config["vllm_max_position_embeddings"] = max_position_embeddings
|
||||||
peft_helper = PEFTHelper.from_dict(config)
|
peft_helper = PEFTHelper.from_dict(config)
|
||||||
|
unexpected_modules: List[Union[list[str], str]]
|
||||||
if os.path.isfile(lora_tensor_path):
|
if os.path.isfile(lora_tensor_path):
|
||||||
tensors: Dict[str, torch.Tensor] = {}
|
tensors: Dict[str, torch.Tensor] = {}
|
||||||
# Find unexpected modules.
|
# Find unexpected modules.
|
||||||
|
|||||||
@ -238,7 +238,9 @@ class ReplicatedLinear(LinearBase):
|
|||||||
assert param.size() == loaded_weight.size()
|
assert param.size() == loaded_weight.size()
|
||||||
param.data.copy_(loaded_weight)
|
param.data.copy_(loaded_weight)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(
|
||||||
|
self, x: torch.Tensor
|
||||||
|
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
bias = self.bias if not self.skip_bias_add else None
|
bias = self.bias if not self.skip_bias_add else None
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
output = self.quant_method.apply(self, x, bias)
|
output = self.quant_method.apply(self, x, bias)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user