Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-20 00:05:56 +08:00)

commit 249824c3bf (parent: 64862d106e)

    Refactor Linear handling in TransformersModel (#12727)

    Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

@@ -2,7 +2,7 @@
 
 import itertools
 from abc import abstractmethod
-from typing import Dict, List, Optional, Tuple
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
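
The typing changes above (and the matching ones throughout this diff) swap typing.List/Dict/Tuple/Set for the built-in generics available since Python 3.9 (PEP 585); only Optional and Union still come from typing. A minimal sketch of the style, using a hypothetical helper that is not part of the diff:

    from typing import Optional

    # Built-in generics (PEP 585) in the annotations; Optional still comes
    # from typing. Hypothetical helper, for illustration only.
    def first_offset(offsets: dict[str, tuple[int, int]],
                     names: list[str]) -> Optional[tuple[int, int]]:
        return offsets[names[0]] if names and names[0] in offsets else None
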
@@ -47,8 +47,8 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
 
 
 def adjust_bitsandbytes_4bit_shard(param: Parameter,
-                                   shard_offsets: Dict[str, Tuple[int, int]],
-                                   loaded_shard_id: str) -> Tuple[int, int]:
+                                   shard_offsets: dict[str, tuple[int, int]],
+                                   loaded_shard_id: str) -> tuple[int, int]:
     """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
 
     total, _ = shard_offsets["total"]
@@ -90,7 +90,7 @@ class LinearMethodBase(QuantizeMethodBase):
     @abstractmethod
     def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
-                       output_partition_sizes: List[int], input_size: int,
+                       output_partition_sizes: list[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
         """Create weights for a linear layer.
@@ -123,7 +123,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
 
     def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
-                       output_partition_sizes: List[int], input_size: int,
+                       output_partition_sizes: list[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
         weight = Parameter(torch.empty(sum(output_partition_sizes),
@@ -179,7 +179,8 @@ class LinearBase(torch.nn.Module):
         self.quant_method = quant_config.get_quant_method(self,
                                                           prefix=prefix)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self,
+                x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
         raise NotImplementedError
 
 
@@ -240,9 +241,8 @@ class ReplicatedLinear(LinearBase):
         assert param.size() == loaded_weight.size()
         param.data.copy_(loaded_weight)
 
-    def forward(
-        self, x: torch.Tensor
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+    def forward(self,
+                x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
         bias = self.bias if not self.skip_bias_add else None
         assert self.quant_method is not None
         output = self.quant_method.apply(self, x, bias)
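
The new signatures make explicit what these layers already return: an (output, output_bias) pair rather than a bare tensor, where output_bias only carries the bias when the layer was built with skip_bias_add=True so the caller can apply it later. A usage sketch with a hypothetical caller:

    import torch

    # Hypothetical helper: unpack the (output, output_bias) pair returned by
    # vLLM linear layers and fold the deferred bias back in if present.
    def apply_linear(layer, x: torch.Tensor) -> torch.Tensor:
        output, output_bias = layer(x)
        if output_bias is not None:
            output = output + output_bias
        return output
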
@@ -288,7 +288,7 @@ class ColumnParallelLinear(LinearBase):
                  skip_bias_add: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 output_sizes: Optional[List[int]] = None,
+                 output_sizes: Optional[list[int]] = None,
                  prefix: str = ""):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                          quant_config, prefix)
@@ -374,7 +374,7 @@ class ColumnParallelLinear(LinearBase):
             loaded_weight = loaded_weight.reshape(1)
         param.load_column_parallel_weight(loaded_weight=loaded_weight)
 
-    def forward(self, input_):
+    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
         bias = self.bias if not self.skip_bias_add else None
 
         # Matrix multiply.
@@ -422,7 +422,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
 
     def __init__(self,
                  input_size: int,
-                 output_sizes: List[int],
+                 output_sizes: list[int],
                  bias: bool = True,
                  gather_output: bool = False,
                  skip_bias_add: bool = False,
@@ -500,7 +500,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             current_shard_offset = 0
             use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                             False)
-            shard_offsets: List[Tuple[int, int, int]] = []
+            shard_offsets: list[tuple[int, int, int]] = []
             for i, output_size in enumerate(self.output_sizes):
                 shard_offsets.append((i, current_shard_offset, output_size))
                 current_shard_offset += output_size
@@ -602,7 +602,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         """
 
         current_shard_offset = 0
-        shard_offsets: List[Tuple[int, int, int]] = []
+        shard_offsets: list[tuple[int, int, int]] = []
         for i, output_size in enumerate(self.output_sizes):
            shard_offsets.append((i, current_shard_offset, output_size))
            current_shard_offset += output_size
@@ -1124,7 +1124,7 @@ class RowParallelLinear(LinearBase):
 
         param.load_row_parallel_weight(loaded_weight=loaded_weight)
 
-    def forward(self, input_):
+    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
         if self.input_is_parallel:
             input_parallel = input_
         else:
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+
 # Copyright 2024 The vLLM team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,7 +15,7 @@
 # limitations under the License.
 """Wrapper around `transformers` models"""
 import re
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Union
 
 import torch
 from torch import nn
@@ -71,23 +72,10 @@ def vllm_flash_attention_forward(
 ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
 
 
-# Linear Layer that is compatible with transformers internal forward
-# TODO: This is a temporary solution, we should find a better way to integrate
-class HFColumnParallelLinear(ColumnParallelLinear):
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return super().forward(input)[0]
-
-
-class HFRowParallelLinear(RowParallelLinear):
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return super().forward(input)[0]
-
-
-def replace_tp_linear_class(orig_module: nn.Linear,
-                            style: str,
-                            quant_config=None):
+def replace_linear_class(
+        linear: nn.Linear,
+        style: str,
+        quant_config=None) -> Union[ColumnParallelLinear, RowParallelLinear]:
     """
     In model configurations, we use a neutral type (string) to specify parallel
     styles, here we use it to translate nn.Linear into vllm-style tp Linear.
@@ -99,26 +87,28 @@ def replace_tp_linear_class(orig_module: nn.Linear,
         raise ValueError(
             f"Unsupported parallel style type {type(style)}, expected str")
 
-    input_size = orig_module.in_features
-    output_size = orig_module.out_features
-    bias = orig_module.bias is not None
+    vllm_linear_cls = {
+        "colwise": ColumnParallelLinear,
+        "rowwise": RowParallelLinear,
+    }.get(style)
 
-    if style == "colwise":
-        return HFColumnParallelLinear(
-            input_size,
-            output_size,
-            bias,
-        )
-    elif style == "rowwise":
-        return HFRowParallelLinear(
-            input_size,
-            output_size,
-            bias,
-        )
-    # We don't consider colwise_rep since it's used in lm_head
-    else:
+    if vllm_linear_cls is None:
         raise ValueError(f"Unsupported parallel style value: {style}")
 
+    class HFCompatibleLinear(vllm_linear_cls):
+        """
+        Wrapper class that removes `output_bias` from returned output.
+        """
+
+        def forward(self, input: torch.Tensor) -> torch.Tensor:
+            return super().forward(input)[0]
+
+    return HFCompatibleLinear(
+        input_size=linear.in_features,
+        output_size=linear.out_features,
+        bias=linear.bias is not None,
+    )
+
 
 class TransformersModel(nn.Module):
     embedding_padding_modules = ["lm_head"]
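
With this refactor, the two HF*ParallelLinear wrapper classes collapse into a single factory: replace_linear_class maps the style string to the vLLM parallel class and builds a dynamic HFCompatibleLinear subclass whose forward drops the trailing output_bias, so transformers' own forward code keeps receiving a plain tensor. A rough usage sketch, assuming a vLLM process whose distributed state is already initialized (sizes and style are illustrative):

    import torch.nn as nn

    linear = nn.Linear(1024, 4096, bias=False)
    tp_linear = replace_linear_class(linear, style="colwise", quant_config=None)
    # tp_linear is an HFCompatibleLinear (here a ColumnParallelLinear subclass)
    # whose forward() returns just the output tensor.
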
@@ -192,16 +182,16 @@ class TransformersModel(nn.Module):
                 "support it yet!")
 
         for child_name, child_module in module.named_children():
-            qual_name = prefix + child_name
+            qual_name = maybe_prefix(prefix, child_name)
             for pattern, style in self.config.base_model_tp_plan.items():
                 if re.match(pattern, qual_name) and isinstance(
                         child_module, nn.Linear):
-                    new_module = replace_tp_linear_class(
-                        child_module, style, self.quant_config)
+                    new_module = replace_linear_class(child_module, style,
+                                                      self.quant_config)
                     setattr(module, child_name, new_module)
                     self.log_replacement(qual_name, child_module, new_module)
             else:
-                self.tensor_parallelize(child_module, prefix=f"{qual_name}.")
+                self.tensor_parallelize(child_module, prefix=qual_name)
 
     def replace_vocab_embed_class(self, module: nn.Module):
         # Use native set input embeddings
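
maybe_prefix (from vLLM's model utilities) replaces the manual prefix + child_name string building and the trailing "." bookkeeping: it joins the parts with a dot only when the prefix is non-empty, so top-level modules keep bare names. A rough equivalent for illustration, not necessarily the exact vLLM implementation:

    def maybe_prefix(prefix: str, name: str) -> str:
        # Join with "." only when there is an enclosing prefix.
        return name if not prefix else f"{prefix}.{name}"

    assert maybe_prefix("", "q_proj") == "q_proj"
    assert maybe_prefix("model.layers.0.self_attn",
                        "q_proj") == "model.layers.0.self_attn.q_proj"
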
@@ -219,7 +209,7 @@ class TransformersModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],  # argument not used
+        kv_caches: list[torch.Tensor],  # argument not used
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
@@ -249,10 +239,10 @@ class TransformersModel(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
         params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
+        loaded_params = set[str]()
         for name, loaded_weight in weights:
             if name not in params_dict:
                 name = f"{self.model.base_model_prefix}.{name}"
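
loaded_params = set[str]() calls the parameterized built-in directly, which is valid at runtime since Python 3.9: it produces an ordinary empty set while letting type checkers infer set[str] without a separate annotation. A quick sketch:

    loaded_params = set[str]()    # runtime: plain empty set; inferred type: set[str]
    loaded_params.add("model.embed_tokens.weight")
    assert isinstance(loaded_params, set)
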