Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-24 18:45:02 +08:00)
Improve TransformersModel UX (#12785)
parent 56534cd577
commit 1a6fcad4c9
vllm/model_executor/models/transformers.py

@@ -15,7 +15,7 @@
 # limitations under the License.
 """Wrapper around `transformers` models"""
 import re
-from typing import Iterable, Optional, Union
+from typing import Iterable, Literal, Optional, Union
 
 import torch
 from torch import nn
@@ -72,15 +72,24 @@ def vllm_flash_attention_forward(
 ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
 
 
+def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
+    logger.debug("%s: %s -> %s", name, old_module, new_module)
+
+
 def replace_linear_class(
         linear: nn.Linear,
-        style: str,
+        style: Literal["colwise", "rowwise"],
         quant_config=None) -> Union[ColumnParallelLinear, RowParallelLinear]:
     """
-    In model configurations, we use a neutral type (string) to specify parallel
-    styles, here we use it to translate nn.Linear into vllm-style tp Linear.
+    Replace nn.Linear with one of vLLM's tensor parallel linear classes.
 
-    Quant config is not supported yet
+    `quant_config` is not yet supported.
+    Args:
+        linear (nn.Linear): `nn.Linear` to be replaced.
+        style (str): Tensor parallel style of the new linear, e.g. "colwise".
+        quant_config (QuantConfig): Quantization config for the new linear.
+    Returns:
+        Union[ColumnParallelLinear, RowParallelLinear]: The new linear.
     """
 
     if not isinstance(style, str):
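
For readers unfamiliar with the two styles named in the new `Literal` annotation: "colwise" shards a linear layer's output dimension across tensor parallel ranks, while "rowwise" shards the input dimension (with an all-reduce over the partial outputs). A minimal, framework-free sketch of that difference; `shard_linear` is a hypothetical stand-in, not vLLM's implementation:

import torch
from torch import nn

def shard_linear(linear: nn.Linear, style: str, rank: int, world: int) -> nn.Linear:
    # Hypothetical helper: illustrates the sharding, not vLLM's implementation.
    weight = linear.weight  # shape: (out_features, in_features)
    if style == "colwise":
        shard = weight.chunk(world, dim=0)[rank]  # split the output dim
    elif style == "rowwise":
        shard = weight.chunk(world, dim=1)[rank]  # split the input dim
    else:
        return linear  # mirrors the new fallback below: leave the layer untouched
    new = nn.Linear(shard.shape[1], shard.shape[0], bias=False)
    new.weight = nn.Parameter(shard.clone())
    return new

full = nn.Linear(8, 4, bias=False)
print(shard_linear(full, "colwise", 0, 2).weight.shape)  # torch.Size([2, 8])
print(shard_linear(full, "rowwise", 0, 2).weight.shape)  # torch.Size([4, 4])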
@@ -93,7 +102,10 @@ def replace_linear_class(
     }.get(style)
 
     if vllm_linear_cls is None:
-        raise ValueError(f"Unsupported parallel style value: {style}")
+        logger.warning(
+            "Unsupported parallel style value: %s. "
+            "This layer will not be tensor parallelized.", style)
+        return linear
 
     class HFCompatibleLinear(vllm_linear_cls):
         """
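
This is the UX change the commit title refers to: an unrecognized style no longer aborts model construction but degrades to an unparallelized layer. A hedged sketch of the new contract (the import path matches where this diff lives, but it may move between vLLM versions):

from torch import nn
from vllm.model_executor.models.transformers import replace_linear_class

layer = nn.Linear(16, 16)
result = replace_linear_class(layer, "diagonal")  # not "colwise"/"rowwise"
assert result is layer  # logs a warning and returns the layer unchanged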
@@ -119,25 +131,24 @@ class TransformersModel(nn.Module):
         super().__init__()
         logger.info("Using Transformers backend.")
 
-        self.vllm_config = vllm_config
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
-        self.quant_config = quant_config
 
         self.config = config
+        self.quant_config = quant_config
         self.vocab_size = config.vocab_size
         self.unpadded_vocab_size = config.vocab_size
 
         self.model: PreTrainedModel = AutoModel.from_config(
             self.config,
             attn_implementation="vllm",
             torch_dtype=vllm_config.model_config.dtype,
             trust_remote_code=vllm_config.model_config.trust_remote_code,
         )
         prefix = self.model.base_model_prefix
 
         # MLP modifications
-        self.tensor_parallelize(self.model)
+        self.apply_base_model_tp_plan(self.model)
 
         # Attention modifications (assumes 1 attention op per hidden layer)
         tp_size = get_tensor_model_parallel_world_size()
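
The constructor above relies on a `transformers` mechanism visible in this file's context: `ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward` registers a custom attention function, and `attn_implementation="vllm"` makes `AutoModel.from_config` dispatch to it. A hedged sketch of the same mechanism with a toy implementation, assuming a recent `transformers` release whose model families route through this registry; `probe_attention` and the "probe" key are illustrative:

import torch
from transformers import AutoModel, GPT2Config
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

def probe_attention(module, query, key, value, attention_mask, **kwargs):
    # Toy stand-in: plain SDPA, ignoring masking/dropout for brevity.
    out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
    # transformers expects (batch, seq, heads, head_dim) back from the hook.
    return out.transpose(1, 2).contiguous(), None

ALL_ATTENTION_FUNCTIONS["probe"] = probe_attention

# Any key present in the registry is a valid attn_implementation.
model = AutoModel.from_config(GPT2Config(), attn_implementation="probe")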
@@ -170,13 +181,13 @@ class TransformersModel(nn.Module):
             config.vocab_size, logit_scale)
         self.sampler = get_sampler()
 
-    def log_replacement(self, name: str, old_module: nn.Module,
-                        new_module: nn.Module):
-        logger.debug("%s: %s -> %s", name, old_module, new_module)
-
-    def tensor_parallelize(self, module: nn.Module, prefix: str = ""):
+    def apply_base_model_tp_plan(self, module: nn.Module, prefix: str = ""):
+        """
+        Apply the base model tensor parallelization plan to a module.
+        Currently only supports linear layers.
+        """
         if (self.config.base_model_tp_plan is None
-                and self.vllm_config.parallel_config.tensor_parallel_size > 1):
+                and get_tensor_model_parallel_world_size() > 1):
             raise ValueError(
                 "Trying to run tensor parallelization but the model does not "
                 "support it yet!")
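
`base_model_tp_plan` is the Hugging Face side of this contract: a dict on the model config mapping glob-like module paths to parallel styles. A hedged sketch of how such a plan is matched against qualified module names during the recursive walk in the next hunk; the plan contents and the `find_style` helper are illustrative, not taken from a real model:

import re
from typing import Optional

# Example plan in the shape Hugging Face configs expose (illustrative values).
base_model_tp_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.gate_proj": "colwise",
    "layers.*.mlp.down_proj": "rowwise",
}

def find_style(qual_name: str) -> Optional[str]:
    for pattern, style in base_model_tp_plan.items():
        # "*" stands for a layer index; translate the glob into a regex.
        if re.match("^" + pattern.replace("*", r"\d+") + "$", qual_name):
            return style
    return None

print(find_style("layers.7.mlp.down_proj"))    # rowwise
print(find_style("layers.7.input_layernorm"))  # None -> layer left untouched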
@@ -189,9 +200,9 @@ class TransformersModel(nn.Module):
                 new_module = replace_linear_class(child_module, style,
                                                   self.quant_config)
                 setattr(module, child_name, new_module)
-                self.log_replacement(qual_name, child_module, new_module)
+                log_replacement(qual_name, child_module, new_module)
             else:
-                self.tensor_parallelize(child_module, prefix=qual_name)
+                self.apply_base_model_tp_plan(child_module, prefix=qual_name)
 
     def replace_vocab_embed_class(self, module: nn.Module):
         # Use native set input embeddings
@@ -201,8 +212,8 @@ class TransformersModel(nn.Module):
             org_num_embeddings=self.config.vocab_size,
             quant_config=None,
         )
-        self.log_replacement("input embedding",
-                             self.model.get_input_embeddings(), new_module)
+        log_replacement("input embedding", self.model.get_input_embeddings(),
+                        new_module)
         self.model.set_input_embeddings(new_module)
 
     def forward(
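
`replace_vocab_embed_class` leans on the standard `PreTrainedModel` embedding accessors, which is what the "use native set input embeddings" comment refers to. A hedged sketch of that swap in isolation, using a plain `nn.Embedding` where vLLM substitutes `VocabParallelEmbedding`:

from torch import nn
from transformers import AutoModel, GPT2Config

model = AutoModel.from_config(GPT2Config())

old = model.get_input_embeddings()      # an nn.Embedding(vocab_size, hidden_size)
new = nn.Embedding(old.num_embeddings, old.embedding_dim)
new.weight.data.copy_(old.weight.data)  # preserve the original weights
model.set_input_embeddings(new)         # the same hook vLLM uses for its swap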