mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-27 05:01:47 +08:00
Signed-off-by: Jerry Zhang <jerryzh168@gmail.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
218 lines
8.7 KiB
Python
218 lines
8.7 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import types
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from vllm.config import ModelConfig
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
|
|
from vllm.model_executor.model_loader.utils import (
|
|
process_weights_after_loading)
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
# Notes for Online Quantization
|
|
# In terms of state of checkpoints, quantization config and their
|
|
# correspondence to online quantization:
|
|
# | Use Case | Checkpoints | model_config.quantization |
|
|
# | no quant | high precision | None |
|
|
# | offline quant | quantized | fp8, torchao etc. |
|
|
# | online quant | high precision | torchao etc. |
|
|
#
|
|
# The process for loading non-quantized checkpoint
|
|
# 1. load non-quantized weights (load_weights)
|
|
# 2. do any additional post processing (process_weights_after_loading)
|
|
#
|
|
# The process for loading offline quantized checkpoint
|
|
# 1. load offline-quantized weights (load_weights)
|
|
# 2. do any additional post processing (process_weights_after_loading)
|
|
|
|
# The process for unquantized model reloading
|
|
# (repeated run in RL training loop)
|
|
# first run
|
|
# UI1. load_weights: load bfloat16 weights
|
|
# UI2. process_weights_after_loading: any additional post processing
|
|
# subsequent run
|
|
# UC1: load_weights: load bfloat16 weights
|
|
# (shouldn't be any issues since we didn't change any attributes
|
|
# of the weights)
|
|
# UC2: process_weights_after_loading: any additional post processing
|
|
|
|
# The process for weight reloading with online quantization
|
|
# (repeated run in RL training loop)
|
|
# first run
|
|
# I1. load_weights: load bfloat16 weights
|
|
# I2. process_weights_after_loading:
|
|
# record weight metadata and attributes for R1 and R2
|
|
# quantize weights to fp8
|
|
# subsequent run
|
|
# (beginning model weight is in fp8)
|
|
# load_weights:
|
|
# R1. restore bfloat16 model weight metadata
|
|
# R2. restore the model weight attributes
|
|
# R3. reload bfloat16 weights
|
|
# R4. quantize weights (by calling process_weights_after_loading),
|
|
# also set `process_weights_after_loading_already_called` to
|
|
# True to stop it from running again
|
|
# process_weights_after_loading (if called):
|
|
# this will be skipped since it has already run in
|
|
# load_weights
|
|
|
|
|
|
def maybe_save_metadata_and_attributes_for_weight_reloading(
        model: nn.Module, model_config: ModelConfig):
    """Record parameter metadata and attributes ahead of online quantization.

    This is step I2 of the online quantization flow described in the Notes
    of online_quantization.py. It only does work when all of the following
    hold; otherwise it returns without touching ``model``:

    * ``model_config.quantization`` is ``"torchao"`` (the only on-the-fly
      quantization supported here),
    * ``model.process_weights_after_loading_already_called`` is not set
      (a warning is logged when it is),
    * the quant config has an ``is_checkpoint_torchao_serialized``
      attribute and it is falsy, i.e. the checkpoint is not already
      torchao-serialized (pre-quantized),
    * ``model.weight_metadata_and_attr_saved`` is not already set.

    On success the following attributes are written on ``model``:

    * ``original_weights_rebuild_keys``: ``{param_name: {"shape", "dtype",
      "device"}}`` for every entry of ``model.named_parameters()``,
    * ``recorded_weight_attr``: ``{param_name: {attr_key: attr}}`` built
      from each parameter's instance ``__dict__``,
    * ``weight_metadata_and_attr_saved``: set to ``True``.

    Args:
        model: module whose parameters are inspected and annotated.
        model_config: configuration providing ``quantization`` and the
            input to ``get_quant_config``.
    """
    # On-the-fly quantization is currently only supported for torchao.
    if model_config.quantization != "torchao":
        return

    already_processed = getattr(
        model, "process_weights_after_loading_already_called", False)
    if already_processed:
        # `process_weights_after_loading` may be invoked several times;
        # every invocation after the first is skipped.
        logger.warning(
            "process_weights_after_loading already called for model %s", model)
        return

    from vllm.model_executor.model_loader.weight_utils import get_quant_config
    quant_config = get_quant_config(model_config, None)

    # A torchao-serialized checkpoint means the pre-quantized case, where
    # there is nothing to save. Only the online quantization case (the
    # attribute exists and is falsy) proceeds to record what steps R1 and
    # R2 need to restore the bfloat16 weights during reload.
    if (not hasattr(quant_config, "is_checkpoint_torchao_serialized")
            or quant_config.is_checkpoint_torchao_serialized):
        return

    # Metadata and attributes are only saved once, during initialization.
    if getattr(model, "weight_metadata_and_attr_saved", False):
        return

    # 1. Floating point metadata (shape, dtype, device) of every parameter,
    #    used to re-create the high precision parameters before reloading.
    assert not hasattr(model, "original_weights_rebuild_keys")
    model.original_weights_rebuild_keys = {
        param_name: {
            "shape": tensor.shape,
            "dtype": tensor.dtype,
            "device": tensor.device,
        }
        for param_name, tensor in model.named_parameters()
    }

    # 2. Per-parameter attributes (e.g. `output_dim`, `weight_loader`) so
    #    they can be re-attached when the weights are reloaded.
    #    Structure: {"weight_name": {"weight_attr_key": attr}}
    assert not hasattr(model, "recorded_weight_attr")
    model.recorded_weight_attr = recorded = {}
    for param_name, tensor in model.named_parameters():
        per_param = recorded[param_name] = {}
        for attr_key in vars(tensor):
            if not hasattr(tensor, attr_key):
                continue
            value = getattr(tensor, attr_key)
            # A method bound to this very parameter is stored as its
            # underlying function object; anything else is stored as is.
            bound_to_this_param = (callable(value)
                                   and hasattr(value, "__self__")
                                   and value.__self__ is tensor)
            per_param[attr_key] = (value.__func__
                                   if bound_to_this_param else value)

    # Mark the metadata and attributes as saved so this does not run again.
    model.weight_metadata_and_attr_saved = True
|
|
|
|
|
|
def _bond_method_to_cls(func, obj):
    """Return ``func`` bound to ``obj`` when it still needs binding.

    ``func`` is handed back unchanged when it is not callable or when it
    already has a ``__self__`` attribute (it is already bound to an
    instance). Otherwise it is wrapped with ``types.MethodType`` so that
    ``obj`` is passed as its first argument.
    """
    needs_binding = callable(func) and not hasattr(func, "__self__")
    return types.MethodType(func, obj) if needs_binding else func
|
|
|
|
|
|
def load_weights_and_online_quantize(model_loader: DefaultModelLoader,
                                     model: nn.Module,
                                     model_config: ModelConfig) -> set[str]:
    """Reload high precision weights into ``model`` and quantize them again.

    Implements steps R1-R4 of the online quantization Notes at the top of
    this file (currently torchao only):

    * R1: re-create every recorded parameter that still exists on the
      model as an empty ``torch.nn.Parameter`` with the recorded shape,
      dtype and device.
    * R2: re-attach the recorded parameter attributes that are missing on
      the re-created parameters.
    * R3: load the weights via ``model.load_weights``.
    * R4: run ``process_weights_after_loading`` and set
      ``model.process_weights_after_loading_already_called`` to ``True``.

    Reads ``model.original_weights_rebuild_keys`` and
    ``model.recorded_weight_attr``, so both must already be present on
    ``model``.

    Args:
        model_loader: loader whose ``get_all_weights(model_config, model)``
            supplies the weights to load.
        model: module whose parameters are restored, reloaded and
            re-quantized in place.
        model_config: model configuration; ``quantization`` must be
            ``"torchao"``.

    Returns:
        The value returned by ``model.load_weights``.

    Raises:
        AssertionError: if ``model_config.quantization`` is not
            ``"torchao"``, or if the recorded parameters do not all share
            one device.
    """
    # TODO: Add fp8 support
    assert model_config.quantization == "torchao", "online " \
        "quantization is only enabled for torchao currently"
    # TODO: use create_weights to restore the weights to original state

    existing_param_names = dict(
        model.named_parameters(remove_duplicate=False)).keys()
    # The root module is keyed by the empty string here, which is what
    # `rpartition` below yields for a parameter name without a dot.
    named_modules = dict(model.named_modules(remove_duplicate=False))
    model_device = None

    # Step R1: restore the quantized weights to empty high precision
    # parameters with the original metadata (shape, dtype, device) so that
    # bfloat16 weights can be loaded properly.
    for name, d in model.original_weights_rebuild_keys.items():
        _shape = d["shape"]
        _dtype = d["dtype"]
        _device = d["device"]
        if model_device is not None:
            assert model_device == _device, "Expecting all weights " \
                "to be in the same device for now, got both: " \
                f"{model_device} and {_device}"
        else:
            model_device = _device

        if name in existing_param_names:
            # `rpartition` (unlike unpacking `rsplit(".", 1)`) also handles
            # parameters registered directly on the root module: their name
            # has no dot, so the module name comes back as "".
            module_name, _, weight_name = name.rpartition(".")
            module = named_modules[module_name]
            setattr(
                module, weight_name,
                torch.nn.Parameter(
                    torch.empty(_shape, dtype=_dtype, device=_device)))

    # Step R2: restore the parameter attributes.
    # recorded_weight_attr is
    # {"weight_name": {"weight_attr_key": attr}}
    # e.g.
    # {
    #     "layer.0.weight": {
    #         "weight_loader": weight_loader_function_object,
    #         "input_dim": 0, ...
    #     },
    #     "layer.1.weight": ...,
    # }
    for full_weight_name, weight_attr_dict in \
            model.recorded_weight_attr.items():
        if not weight_attr_dict:
            # Nothing to restore; skip the lookup below entirely so that a
            # parameter without recorded attributes is never touched.
            continue
        # The owning module and the parameter are the same for every
        # attribute of this weight, so resolve them once per weight.
        module_name, _, weight_name = full_weight_name.rpartition(".")
        module = named_modules[module_name]
        weight = getattr(module, weight_name)
        for attr_name, attr in weight_attr_dict.items():
            if not hasattr(weight, attr_name):
                setattr(weight, attr_name, _bond_method_to_cls(attr, weight))

    # Step R3: reload bfloat16 / high precision weights
    loaded_weights = model.load_weights(
        model_loader.get_all_weights(model_config, model))

    # Step R4: online quantize the weights by manually processing weights
    # after loading. The flag is cleared first, presumably so the
    # processing is not skipped as "already called" (TODO confirm against
    # process_weights_after_loading), and set afterwards to stop it from
    # running again.
    model.process_weights_after_loading_already_called = False
    process_weights_after_loading(model, model_config, model_device)
    model.process_weights_after_loading_already_called = True
    return loaded_weights
|