Refactor Linear handling in TransformersModel (#12727)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-02-05 04:31:12 +00:00 committed by GitHub
parent 64862d106e
commit 249824c3bf
2 changed files with 48 additions and 58 deletions

vllm/model_executor/layers/linear.py

@@ -2,7 +2,7 @@
 import itertools
 from abc import abstractmethod
-from typing import Dict, List, Optional, Tuple
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
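Most of the churn in this file is the same mechanical change: the deprecated `typing.Dict`/`List`/`Set`/`Tuple` aliases are replaced by the built-in generics from PEP 585, which are subscriptable at runtime on Python 3.9+. A minimal sketch of the new spelling (the helper below is illustrative only, not taken from the diff):

    # Built-in generics replace typing.List/Dict/Tuple (PEP 585, Python 3.9+).
    def shard_offsets(total: int, ratios: list[int]) -> dict[str, tuple[int, int]]:
        """Toy helper: map shard names to (offset, size) pairs."""
        offsets: dict[str, tuple[int, int]] = {}
        start = 0
        for i, ratio in enumerate(ratios):
            size = total * ratio // sum(ratios)
            offsets[f"shard_{i}"] = (start, size)
            start += size
        return offsets

    print(shard_offsets(12, [1, 2]))  # {'shard_0': (0, 4), 'shard_1': (4, 8)}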
@@ -47,8 +47,8 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
 def adjust_bitsandbytes_4bit_shard(param: Parameter,
-                                   shard_offsets: Dict[str, Tuple[int, int]],
-                                   loaded_shard_id: str) -> Tuple[int, int]:
+                                   shard_offsets: dict[str, tuple[int, int]],
+                                   loaded_shard_id: str) -> tuple[int, int]:
     """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
     total, _ = shard_offsets["total"]
@@ -90,7 +90,7 @@ class LinearMethodBase(QuantizeMethodBase):
     @abstractmethod
     def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
-                       output_partition_sizes: List[int], input_size: int,
+                       output_partition_sizes: list[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
         """Create weights for a linear layer.
@@ -123,7 +123,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
     def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
-                       output_partition_sizes: List[int], input_size: int,
+                       output_partition_sizes: list[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
         weight = Parameter(torch.empty(sum(output_partition_sizes),
@@ -179,7 +179,8 @@ class LinearBase(torch.nn.Module):
         self.quant_method = quant_config.get_quant_method(self,
                                                           prefix=prefix)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self,
+                x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
         raise NotImplementedError
@@ -240,9 +241,8 @@ class ReplicatedLinear(LinearBase):
         assert param.size() == loaded_weight.size()
         param.data.copy_(loaded_weight)
 
-    def forward(
-        self, x: torch.Tensor
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+    def forward(self,
+                x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
         bias = self.bias if not self.skip_bias_add else None
         assert self.quant_method is not None
         output = self.quant_method.apply(self, x, bias)
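The new annotations make explicit what these layers already did: `forward` returns an `(output, output_bias)` pair, where `output_bias` is only non-None when the layer was built with `skip_bias_add=True` and the caller applies the bias itself. A hedged sketch of the calling convention (the projection names are illustrative, not from this diff):

    import torch
    import torch.nn.functional as F

    def mlp_forward(gate_up_proj, down_proj, x: torch.Tensor) -> torch.Tensor:
        # vLLM parallel linear layers return (output, output_bias); most
        # callers discard the bias term because skip_bias_add is False.
        gate_up, _ = gate_up_proj(x)          # e.g. a ColumnParallelLinear
        out, _ = down_proj(F.silu(gate_up))   # e.g. a RowParallelLinear
        return out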
@@ -288,7 +288,7 @@ class ColumnParallelLinear(LinearBase):
                  skip_bias_add: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 output_sizes: Optional[List[int]] = None,
+                 output_sizes: Optional[list[int]] = None,
                  prefix: str = ""):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                          quant_config, prefix)
@@ -374,7 +374,7 @@ class ColumnParallelLinear(LinearBase):
             loaded_weight = loaded_weight.reshape(1)
         param.load_column_parallel_weight(loaded_weight=loaded_weight)
 
-    def forward(self, input_):
+    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
         bias = self.bias if not self.skip_bias_add else None
 
         # Matrix multiply.
@@ -422,7 +422,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
     def __init__(self,
                  input_size: int,
-                 output_sizes: List[int],
+                 output_sizes: list[int],
                  bias: bool = True,
                  gather_output: bool = False,
                  skip_bias_add: bool = False,
@@ -500,7 +500,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             current_shard_offset = 0
             use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                             False)
-            shard_offsets: List[Tuple[int, int, int]] = []
+            shard_offsets: list[tuple[int, int, int]] = []
             for i, output_size in enumerate(self.output_sizes):
                 shard_offsets.append((i, current_shard_offset, output_size))
                 current_shard_offset += output_size
@@ -602,7 +602,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         """
         current_shard_offset = 0
-        shard_offsets: List[Tuple[int, int, int]] = []
+        shard_offsets: list[tuple[int, int, int]] = []
         for i, output_size in enumerate(self.output_sizes):
             shard_offsets.append((i, current_shard_offset, output_size))
             current_shard_offset += output_size
@@ -1124,7 +1124,7 @@ class RowParallelLinear(LinearBase):
         param.load_row_parallel_weight(loaded_weight=loaded_weight)
 
-    def forward(self, input_):
+    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
         if self.input_is_parallel:
             input_parallel = input_
         else:

vllm/model_executor/models/transformers.py

@@ -1,4 +1,5 @@
+# SPDX-License-Identifier: Apache-2.0
 # Copyright 2024 The vLLM team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,7 +15,7 @@
 # limitations under the License.
 """Wrapper around `transformers` models"""
 import re
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Union
 
 import torch
 from torch import nn
@@ -71,23 +72,10 @@ def vllm_flash_attention_forward(
 ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
 
-# Linear Layer that is compatible with transformers internal forward
-# TODO: This is a temporary solution, we should find a better way to integrate
-class HFColumnParallelLinear(ColumnParallelLinear):
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return super().forward(input)[0]
-
-
-class HFRowParallelLinear(RowParallelLinear):
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return super().forward(input)[0]
-
-
-def replace_tp_linear_class(orig_module: nn.Linear,
-                            style: str,
-                            quant_config=None):
+def replace_linear_class(
+        linear: nn.Linear,
+        style: str,
+        quant_config=None) -> Union[ColumnParallelLinear, RowParallelLinear]:
     """
     In model configurations, we use a neutral type (string) to specify parallel
     styles, here we use it to translate nn.Linear into vllm-style tp Linear.
@@ -99,26 +87,28 @@ def replace_tp_linear_class(orig_module: nn.Linear,
         raise ValueError(
             f"Unsupported parallel style type {type(style)}, expected str")
 
-    input_size = orig_module.in_features
-    output_size = orig_module.out_features
-    bias = orig_module.bias is not None
-
-    if style == "colwise":
-        return HFColumnParallelLinear(
-            input_size,
-            output_size,
-            bias,
-        )
-    elif style == "rowwise":
-        return HFRowParallelLinear(
-            input_size,
-            output_size,
-            bias,
-        )
-    # We don't consider colwise_rep since it's used in lm_head
-    else:
-        raise ValueError(f"Unsupported parallel style value: {style}")
+    vllm_linear_cls = {
+        "colwise": ColumnParallelLinear,
+        "rowwise": RowParallelLinear,
+    }.get(style)
+
+    if vllm_linear_cls is None:
+        raise ValueError(f"Unsupported parallel style value: {style}")
+
+    class HFCompatibleLinear(vllm_linear_cls):
+        """
+        Wrapper class that removes `output_bias` from returned output.
+        """
+
+        def forward(self, input: torch.Tensor) -> torch.Tensor:
+            return super().forward(input)[0]
+
+    return HFCompatibleLinear(
+        input_size=linear.in_features,
+        output_size=linear.out_features,
+        bias=linear.bias is not None,
+    )
 
 
 class TransformersModel(nn.Module):
     embedding_padding_modules = ["lm_head"]
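The refactor folds the two hand-written `HFColumnParallelLinear`/`HFRowParallelLinear` wrappers into a single subclass created on the fly from whichever vLLM class the tp-plan style selects, so the `output_bias`-stripping `forward` is written once. A standalone sketch of the same pattern, using stand-in classes rather than the real vLLM layers (all names below are illustrative only):

    import torch
    from torch import nn

    class FakeColwise(nn.Linear):
        """Stand-in for ColumnParallelLinear: returns (output, output_bias)."""

        def forward(self, x: torch.Tensor):
            return super().forward(x), None

    class FakeRowwise(FakeColwise):
        """Stand-in for RowParallelLinear."""

    def replace_linear_class_sketch(linear: nn.Linear, style: str) -> nn.Module:
        cls = {"colwise": FakeColwise, "rowwise": FakeRowwise}.get(style)
        if cls is None:
            raise ValueError(f"Unsupported parallel style value: {style}")

        class HFCompatible(cls):
            """Drop output_bias so transformers code sees a plain tensor."""

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return super().forward(x)[0]

        return HFCompatible(linear.in_features, linear.out_features,
                            bias=linear.bias is not None)

    layer = replace_linear_class_sketch(nn.Linear(8, 16), "colwise")
    print(layer(torch.randn(2, 8)).shape)  # torch.Size([2, 16])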
@@ -192,16 +182,16 @@ class TransformersModel(nn.Module):
                 "support it yet!")
 
         for child_name, child_module in module.named_children():
-            qual_name = prefix + child_name
+            qual_name = maybe_prefix(prefix, child_name)
             for pattern, style in self.config.base_model_tp_plan.items():
                 if re.match(pattern, qual_name) and isinstance(
                         child_module, nn.Linear):
-                    new_module = replace_tp_linear_class(
-                        child_module, style, self.quant_config)
+                    new_module = replace_linear_class(child_module, style,
+                                                      self.quant_config)
                     setattr(module, child_name, new_module)
                     self.log_replacement(qual_name, child_module, new_module)
             else:
-                self.tensor_parallelize(child_module, prefix=f"{qual_name}.")
+                self.tensor_parallelize(child_module, prefix=qual_name)
 
     def replace_vocab_embed_class(self, module: nn.Module):
         # Use native set input embeddings
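Switching from string concatenation to `maybe_prefix` means the recursion no longer has to append a trailing `.` to the prefix itself; the helper joins non-empty prefixes with a dot and leaves top-level names untouched. A sketch of the behaviour being relied on (assuming the helper matches the one in `vllm.model_executor.models.utils`):

    def maybe_prefix(prefix: str, name: str) -> str:
        """Join name onto prefix with a dot, unless the prefix is empty."""
        return name if not prefix else f"{prefix}.{name}"

    assert maybe_prefix("", "model") == "model"
    assert maybe_prefix("model.layers.0", "mlp") == "model.layers.0.mlp"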
@@ -219,7 +209,7 @@ class TransformersModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],  # argument not used
+        kv_caches: list[torch.Tensor],  # argument not used
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
@@ -249,10 +239,10 @@ class TransformersModel(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
         params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
+        loaded_params = set[str]()
         for name, loaded_weight in weights:
             if name not in params_dict:
                 name = f"{self.model.base_model_prefix}.{name}"
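`load_weights` keeps its fallback of prepending the wrapped model's `base_model_prefix` when a checkpoint name is missing from the parameter dict; the only change here is that the return type and the `loaded_params` accumulator now use the built-in `set[str]`. A toy sketch of that remapping logic in isolation (names and shapes below are illustrative, not the real checkpoint layout):

    import torch

    def remap_names(weights: dict[str, torch.Tensor], params: set[str],
                    base_model_prefix: str = "model") -> set[str]:
        """Prepend the base-model prefix to checkpoint names that miss it."""
        loaded_params = set[str]()
        for name in weights:
            if name not in params:
                name = f"{base_model_prefix}.{name}"
            if name in params:
                loaded_params.add(name)
        return loaded_params

    params = {"model.embed_tokens.weight", "lm_head.weight"}
    ckpt = {"embed_tokens.weight": torch.zeros(1), "lm_head.weight": torch.zeros(1)}
    print(sorted(remap_names(ckpt, params)))
    # ['lm_head.weight', 'model.embed_tokens.weight']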