diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 08f1e103e53b7..da8db08fe7152 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -2,7 +2,7 @@
 
 import itertools
 from abc import abstractmethod
-from typing import Dict, List, Optional, Tuple
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -47,8 +47,8 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
 
 
 def adjust_bitsandbytes_4bit_shard(param: Parameter,
-                                   shard_offsets: Dict[str, Tuple[int, int]],
-                                   loaded_shard_id: str) -> Tuple[int, int]:
+                                   shard_offsets: dict[str, tuple[int, int]],
+                                   loaded_shard_id: str) -> tuple[int, int]:
     """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
 
     total, _ = shard_offsets["total"]
@@ -90,7 +90,7 @@ class LinearMethodBase(QuantizeMethodBase):
     @abstractmethod
     def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
-                       output_partition_sizes: List[int], input_size: int,
+                       output_partition_sizes: list[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
         """Create weights for a linear layer.
@@ -123,7 +123,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
 
     def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
-                       output_partition_sizes: List[int], input_size: int,
+                       output_partition_sizes: list[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
         weight = Parameter(torch.empty(sum(output_partition_sizes),
@@ -179,7 +179,8 @@ class LinearBase(torch.nn.Module):
             self.quant_method = quant_config.get_quant_method(self,
                                                               prefix=prefix)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self,
+                x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
         raise NotImplementedError
 
 
@@ -240,9 +241,8 @@ class ReplicatedLinear(LinearBase):
         assert param.size() == loaded_weight.size()
         param.data.copy_(loaded_weight)
 
-    def forward(
-        self, x: torch.Tensor
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+    def forward(self,
+                x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
         bias = self.bias if not self.skip_bias_add else None
         assert self.quant_method is not None
         output = self.quant_method.apply(self, x, bias)
@@ -288,7 +288,7 @@ class ColumnParallelLinear(LinearBase):
                  skip_bias_add: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 output_sizes: Optional[List[int]] = None,
+                 output_sizes: Optional[list[int]] = None,
                  prefix: str = ""):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                          quant_config, prefix)
@@ -374,7 +374,7 @@ class ColumnParallelLinear(LinearBase):
             loaded_weight = loaded_weight.reshape(1)
         param.load_column_parallel_weight(loaded_weight=loaded_weight)
 
-    def forward(self, input_):
+    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
         bias = self.bias if not self.skip_bias_add else None
 
         # Matrix multiply.
@@ -422,7 +422,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
 
     def __init__(self,
                  input_size: int,
-                 output_sizes: List[int],
+                 output_sizes: list[int],
                  bias: bool = True,
                  gather_output: bool = False,
                  skip_bias_add: bool = False,
@@ -500,7 +500,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             current_shard_offset = 0
             use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                             False)
-            shard_offsets: List[Tuple[int, int, int]] = []
+            shard_offsets: list[tuple[int, int, int]] = []
             for i, output_size in enumerate(self.output_sizes):
                 shard_offsets.append((i, current_shard_offset, output_size))
                 current_shard_offset += output_size
@@ -602,7 +602,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         """
 
         current_shard_offset = 0
-        shard_offsets: List[Tuple[int, int, int]] = []
+        shard_offsets: list[tuple[int, int, int]] = []
         for i, output_size in enumerate(self.output_sizes):
             shard_offsets.append((i, current_shard_offset, output_size))
             current_shard_offset += output_size
@@ -1124,7 +1124,7 @@ class RowParallelLinear(LinearBase):
 
         param.load_row_parallel_weight(loaded_weight=loaded_weight)
 
-    def forward(self, input_):
+    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
         if self.input_is_parallel:
             input_parallel = input_
         else:
diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py
index 160beaa146ea3..dfc7143823d5a 100644
--- a/vllm/model_executor/models/transformers.py
+++ b/vllm/model_executor/models/transformers.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+
 # Copyright 2024 The vLLM team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,7 +15,7 @@
 # limitations under the License.
 """Wrapper around `transformers` models"""
 import re
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Union
 
 import torch
 from torch import nn
@@ -71,23 +72,10 @@ def vllm_flash_attention_forward(
 ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
 
 
-# Linear Layer that is compatible with transformers internal forward
-# TODO: This is a temporary solution, we should find a better way to integrate
-class HFColumnParallelLinear(ColumnParallelLinear):
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return super().forward(input)[0]
-
-
-class HFRowParallelLinear(RowParallelLinear):
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return super().forward(input)[0]
-
-
-def replace_tp_linear_class(orig_module: nn.Linear,
-                            style: str,
-                            quant_config=None):
+def replace_linear_class(
+        linear: nn.Linear,
+        style: str,
+        quant_config=None) -> Union[ColumnParallelLinear, RowParallelLinear]:
     """
     In model configurations, we use a neutral type (string) to specify parallel
     styles, here we use it to translate nn.Linear into vllm-style tp Linear.
@@ -99,26 +87,28 @@ def replace_tp_linear_class(orig_module: nn.Linear,
         raise ValueError(
             f"Unsupported parallel style type {type(style)}, expected str")
 
-    input_size = orig_module.in_features
-    output_size = orig_module.out_features
-    bias = orig_module.bias is not None
+    vllm_linear_cls = {
+        "colwise": ColumnParallelLinear,
+        "rowwise": RowParallelLinear,
+    }.get(style)
 
-    if style == "colwise":
-        return HFColumnParallelLinear(
-            input_size,
-            output_size,
-            bias,
-        )
-    elif style == "rowwise":
-        return HFRowParallelLinear(
-            input_size,
-            output_size,
-            bias,
-        )
-    # We don't consider colwise_rep since it's used in lm_head
-    else:
+    if vllm_linear_cls is None:
         raise ValueError(f"Unsupported parallel style value: {style}")
 
+    class HFCompatibleLinear(vllm_linear_cls):
+        """
+        Wrapper class that removes `output_bias` from returned output.
+        """
+
+        def forward(self, input: torch.Tensor) -> torch.Tensor:
+            return super().forward(input)[0]
+
+    return HFCompatibleLinear(
+        input_size=linear.in_features,
+        output_size=linear.out_features,
+        bias=linear.bias is not None,
+    )
+
 
 class TransformersModel(nn.Module):
     embedding_padding_modules = ["lm_head"]
@@ -192,16 +182,16 @@ class TransformersModel(nn.Module):
                 "support it yet!")
 
         for child_name, child_module in module.named_children():
-            qual_name = prefix + child_name
+            qual_name = maybe_prefix(prefix, child_name)
             for pattern, style in self.config.base_model_tp_plan.items():
                 if re.match(pattern, qual_name) and isinstance(
                         child_module, nn.Linear):
-                    new_module = replace_tp_linear_class(
-                        child_module, style, self.quant_config)
+                    new_module = replace_linear_class(child_module, style,
+                                                      self.quant_config)
                     setattr(module, child_name, new_module)
                     self.log_replacement(qual_name, child_module, new_module)
             else:
-                self.tensor_parallelize(child_module, prefix=f"{qual_name}.")
+                self.tensor_parallelize(child_module, prefix=qual_name)
 
     def replace_vocab_embed_class(self, module: nn.Module):
         # Use native set input embeddings
@@ -219,7 +209,7 @@ class TransformersModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],  # argument not used
+        kv_caches: list[torch.Tensor],  # argument not used
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
@@ -249,10 +239,10 @@ class TransformersModel(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
         params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
+        loaded_params = set[str]()
         for name, loaded_weight in weights:
             if name not in params_dict:
                 name = f"{self.model.base_model_prefix}.{name}"
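
The core of the transformers.py change is that replace_linear_class now builds the wrapper subclass at call time, instead of keeping one hand-written HF*ParallelLinear class per parallel style. Below is a minimal, self-contained sketch of that pattern; TwoOutputLinear and replace_linear are hypothetical stand-ins for vLLM's ColumnParallelLinear/RowParallelLinear, which return an (output, output_bias) tuple from forward().

import torch
from torch import nn


class TwoOutputLinear(nn.Linear):
    """Hypothetical stand-in for a vLLM parallel linear layer whose
    forward() returns (output, output_bias) instead of a single tensor."""

    def forward(self, input: torch.Tensor):
        output = nn.functional.linear(input, self.weight, self.bias)
        return output, None  # second element plays the role of `output_bias`


def replace_linear(linear: nn.Linear, parallel_cls=TwoOutputLinear):
    """Mirror of the diff's pattern: subclass the chosen parallel class and
    drop the extra output so the module behaves like a plain nn.Linear."""

    class HFCompatibleLinear(parallel_cls):

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            return super().forward(input)[0]

    return HFCompatibleLinear(
        linear.in_features,
        linear.out_features,
        bias=linear.bias is not None,
    )


if __name__ == "__main__":
    replaced = replace_linear(nn.Linear(8, 16))
    out = replaced(torch.randn(2, 8))
    assert isinstance(out, torch.Tensor) and out.shape == (2, 16)

Because the subclass is created inside the function from whichever class the style maps to, the same single-tensor-forward wrapper applies to every parallel style, and the returned module stays call-compatible with Transformers code that expects nn.Linear's output.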