mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-21 07:17:00 +08:00
Move online quantization to model.load_weights (#26327)
Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
This commit is contained in:
parent
1395461f5f
commit
da94c7c0eb
@ -62,7 +62,7 @@ ray.init()
|
|||||||
|
|
||||||
# Create a placement group that reserves GPU 1–2 for the vLLM inference engine.
|
# Create a placement group that reserves GPU 1–2 for the vLLM inference engine.
|
||||||
# Learn more about Ray placement groups:
|
# Learn more about Ray placement groups:
|
||||||
# https://docs.ray.io/en/latest/placement-groups.html
|
# https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html
|
||||||
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
|
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
|
||||||
ray.get(pg_inference.ready())
|
ray.get(pg_inference.ready())
|
||||||
scheduling_inference = PlacementGroupSchedulingStrategy(
|
scheduling_inference = PlacementGroupSchedulingStrategy(
|
||||||
|
|||||||
162
examples/offline_inference/rlhf_online_quant.py
Normal file
162
examples/offline_inference/rlhf_online_quant.py
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray.
|
||||||
|
|
||||||
|
The script separates training and inference workloads onto distinct GPUs
|
||||||
|
so that Ray can manage process placement and inter-process communication.
|
||||||
|
A Hugging Face Transformer model occupies GPU 0 for training, whereas a
|
||||||
|
tensor-parallel vLLM inference engine occupies GPU 1–2.
|
||||||
|
|
||||||
|
The example performs the following steps:
|
||||||
|
|
||||||
|
* Load the training model on GPU 0.
|
||||||
|
* Split the inference model across GPUs 1–2 using vLLM's tensor parallelism
|
||||||
|
and Ray placement groups.
|
||||||
|
* Generate text from a list of prompts using the inference engine.
|
||||||
|
* Update the weights of the training model and broadcast the updated weights
|
||||||
|
to the inference engine by using a Ray collective RPC group. Note that
|
||||||
|
for demonstration purposes we simply zero out the weights.
|
||||||
|
|
||||||
|
For a production-ready implementation that supports multiple training and
|
||||||
|
inference replicas, see the OpenRLHF framework:
|
||||||
|
https://github.com/OpenRLHF/OpenRLHF
|
||||||
|
|
||||||
|
This example assumes a single-node cluster with three GPUs, but Ray
|
||||||
|
supports multi-node clusters. vLLM expects the GPUs are only used for vLLM
|
||||||
|
workloads. Residual GPU activity interferes with vLLM memory profiling and
|
||||||
|
causes unexpected behavior.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
import ray
|
||||||
|
import torch
|
||||||
|
from ray.util.placement_group import placement_group
|
||||||
|
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||||
|
from rlhf_utils import stateless_init_process_group
|
||||||
|
from torchao.core.config import config_to_dict
|
||||||
|
from torchao.quantization import (
|
||||||
|
Float8DynamicActivationFloat8WeightConfig,
|
||||||
|
PerRow,
|
||||||
|
)
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.utils.network_utils import get_ip, get_open_port
|
||||||
|
|
||||||
|
|
||||||
|
class MyLLM(LLM):
|
||||||
|
"""Configure the vLLM worker for Ray placement group execution."""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
# Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray
|
||||||
|
# so that vLLM can manage its own device placement within the worker.
|
||||||
|
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# Load the OPT-125M model onto GPU 0 for the training workload.
|
||||||
|
train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
|
||||||
|
train_model.to("cuda:0")
|
||||||
|
|
||||||
|
# Initialize Ray and set the visible devices. The vLLM engine will
|
||||||
|
# be placed on GPUs 1 and 2.
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
|
||||||
|
ray.init()
|
||||||
|
|
||||||
|
# Create a placement group that reserves GPU 1–2 for the vLLM inference engine.
|
||||||
|
# Learn more about Ray placement groups:
|
||||||
|
# https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html
|
||||||
|
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
|
||||||
|
ray.get(pg_inference.ready())
|
||||||
|
scheduling_inference = PlacementGroupSchedulingStrategy(
|
||||||
|
placement_group=pg_inference,
|
||||||
|
placement_group_capture_child_tasks=True,
|
||||||
|
placement_group_bundle_index=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Launch the vLLM inference engine. The `enforce_eager` flag reduces
|
||||||
|
# start-up latency.
|
||||||
|
|
||||||
|
# generate torchao quantization config for RL rollout
|
||||||
|
# see https://github.com/vllm-project/vllm/pull/23014 for instructions to
|
||||||
|
# use serialized config files instead of passing around json string
|
||||||
|
config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
|
||||||
|
|
||||||
|
json_str = json.dumps(config_to_dict(config))
|
||||||
|
|
||||||
|
llm = ray.remote(
|
||||||
|
num_cpus=0,
|
||||||
|
num_gpus=0,
|
||||||
|
scheduling_strategy=scheduling_inference,
|
||||||
|
)(MyLLM).remote(
|
||||||
|
model="facebook/opt-125m",
|
||||||
|
hf_overrides={"quantization_config_dict_json": json_str},
|
||||||
|
enforce_eager=True,
|
||||||
|
worker_extension_cls="rlhf_utils.WorkerExtension",
|
||||||
|
tensor_parallel_size=2,
|
||||||
|
distributed_executor_backend="ray",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate text from the prompts.
|
||||||
|
prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
"The president of the United States is",
|
||||||
|
"The capital of France is",
|
||||||
|
"The future of AI is",
|
||||||
|
]
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(temperature=0)
|
||||||
|
|
||||||
|
outputs = ray.get(llm.generate.remote(prompts, sampling_params))
|
||||||
|
|
||||||
|
print("-" * 50)
|
||||||
|
for output in outputs:
|
||||||
|
prompt = output.prompt
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
||||||
|
print("-" * 50)
|
||||||
|
|
||||||
|
# Set up the communication channel between the training process and the
|
||||||
|
# inference engine.
|
||||||
|
master_address = get_ip()
|
||||||
|
master_port = get_open_port()
|
||||||
|
|
||||||
|
handle = llm.collective_rpc.remote(
|
||||||
|
"init_weight_update_group", args=(master_address, master_port, 1, 3)
|
||||||
|
)
|
||||||
|
|
||||||
|
model_update_group = stateless_init_process_group(
|
||||||
|
master_address, master_port, 0, 3, torch.device("cuda:0")
|
||||||
|
)
|
||||||
|
ray.get(handle)
|
||||||
|
|
||||||
|
# Simulate a training step by zeroing out all model weights.
|
||||||
|
# In a real RLHF training loop the weights would be updated using the gradient
|
||||||
|
# from an RL objective such as PPO on a reward model.
|
||||||
|
for name, p in train_model.named_parameters():
|
||||||
|
p.data.zero_()
|
||||||
|
|
||||||
|
# Synchronize the updated weights to the inference engine.
|
||||||
|
for name, p in train_model.named_parameters():
|
||||||
|
dtype_name = str(p.dtype).split(".")[-1]
|
||||||
|
handle = llm.collective_rpc.remote(
|
||||||
|
"update_weight", args=(name, dtype_name, p.shape)
|
||||||
|
)
|
||||||
|
model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
|
||||||
|
ray.get(handle)
|
||||||
|
|
||||||
|
# Verify that the inference weights have been updated.
|
||||||
|
assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))
|
||||||
|
|
||||||
|
# Generate text with the updated model. The output is expected to be nonsense
|
||||||
|
# because the weights are zero.
|
||||||
|
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
|
||||||
|
print("-" * 50)
|
||||||
|
for output in outputs_updated:
|
||||||
|
prompt = output.prompt
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
||||||
|
print("-" * 50)
|
||||||
@ -22,6 +22,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
|||||||
fastsafetensors_weights_iterator,
|
fastsafetensors_weights_iterator,
|
||||||
filter_duplicate_safetensors_files,
|
filter_duplicate_safetensors_files,
|
||||||
filter_files_not_needed_for_inference,
|
filter_files_not_needed_for_inference,
|
||||||
|
get_quant_config,
|
||||||
maybe_download_from_modelscope,
|
maybe_download_from_modelscope,
|
||||||
multi_thread_pt_weights_iterator,
|
multi_thread_pt_weights_iterator,
|
||||||
multi_thread_safetensors_weights_iterator,
|
multi_thread_safetensors_weights_iterator,
|
||||||
@ -273,42 +274,17 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
|
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
|
||||||
if model_config.quantization == "torchao" and torchao_version_at_least(
|
if model_config.quantization == "torchao":
|
||||||
"0.14.0"
|
quant_config = get_quant_config(model_config, self.load_config)
|
||||||
):
|
if (
|
||||||
self.load_config.safetensors_load_strategy = "torchao"
|
hasattr(quant_config, "is_checkpoint_torchao_serialized")
|
||||||
|
and quant_config.is_checkpoint_torchao_serialized
|
||||||
|
and torchao_version_at_least("0.14.0")
|
||||||
|
):
|
||||||
|
self.load_config.safetensors_load_strategy = "torchao"
|
||||||
|
|
||||||
weights_to_load = {name for name, _ in model.named_parameters()}
|
weights_to_load = {name for name, _ in model.named_parameters()}
|
||||||
|
loaded_weights = model.load_weights(self.get_all_weights(model_config, model))
|
||||||
# if we don't have `model.weight_metadata_and_attr_saved` defined and
|
|
||||||
# set to True, it means that this is either offline quantization case
|
|
||||||
# or the first run of online quantization
|
|
||||||
# see online_quantization.py for detailed notes
|
|
||||||
offline_quantization_or_first_run_of_online_quantization = not getattr(
|
|
||||||
model, "weight_metadata_and_attr_saved", False
|
|
||||||
)
|
|
||||||
|
|
||||||
if model_config.quantization is None:
|
|
||||||
# model is not quantized
|
|
||||||
loaded_weights = model.load_weights(
|
|
||||||
self.get_all_weights(model_config, model)
|
|
||||||
)
|
|
||||||
elif offline_quantization_or_first_run_of_online_quantization:
|
|
||||||
# case 1: offline quantized checkpoint
|
|
||||||
# case 2: Step I1 first run of weight loading with
|
|
||||||
# online quantization
|
|
||||||
# see online_quantization.py for detailed notes
|
|
||||||
loaded_weights = model.load_weights(
|
|
||||||
self.get_all_weights(model_config, model)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# to avoid circular dependency
|
|
||||||
from vllm.model_executor.model_loader.online_quantization import (
|
|
||||||
load_weights_and_online_quantize,
|
|
||||||
)
|
|
||||||
|
|
||||||
# subsequent runs of weight loading with online
|
|
||||||
# quantization
|
|
||||||
loaded_weights = load_weights_and_online_quantize(self, model, model_config)
|
|
||||||
|
|
||||||
self.counter_after_loading_weights = time.perf_counter()
|
self.counter_after_loading_weights = time.perf_counter()
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
|
|||||||
@ -2,13 +2,13 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import types
|
import types
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
|
|
||||||
from vllm.model_executor.model_loader.utils import process_weights_after_loading
|
from vllm.model_executor.model_loader.utils import process_weights_after_loading
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -56,6 +56,9 @@ logger = init_logger(__name__)
|
|||||||
# R4. quantize weights (by calling process_weights_after_loading),
|
# R4. quantize weights (by calling process_weights_after_loading),
|
||||||
# also set `process_weights_after_loading_already_called` to
|
# also set `process_weights_after_loading_already_called` to
|
||||||
# True to stop it from running again
|
# True to stop it from running again
|
||||||
|
# R5. (workaround for cudagraph), we restore the weight params to original quantized
|
||||||
|
# weights params, and use original_weight_param.copy_(updated_weight_param) so that
|
||||||
|
# the weight update work well with cudagraph
|
||||||
# process_weights_after_loading (if called):
|
# process_weights_after_loading (if called):
|
||||||
# this will be skipped since it's already ran in
|
# this will be skipped since it's already ran in
|
||||||
# load_weights
|
# load_weights
|
||||||
@ -69,14 +72,6 @@ def maybe_save_metadata_and_attributes_for_weight_reloading(
|
|||||||
if model_config.quantization != "torchao":
|
if model_config.quantization != "torchao":
|
||||||
return
|
return
|
||||||
|
|
||||||
if getattr(model, "process_weights_after_loading_already_called", False):
|
|
||||||
# In case `process_weights_after_loading` is called multiple times
|
|
||||||
# we'll skip it at later times
|
|
||||||
logger.warning(
|
|
||||||
"process_weights_after_loading already called for model %s", model
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
from vllm.model_executor.model_loader.weight_utils import get_quant_config
|
from vllm.model_executor.model_loader.weight_utils import get_quant_config
|
||||||
|
|
||||||
quant_config = get_quant_config(model_config, None)
|
quant_config = get_quant_config(model_config, None)
|
||||||
@ -137,6 +132,7 @@ def maybe_save_metadata_and_attributes_for_weight_reloading(
|
|||||||
else:
|
else:
|
||||||
model.recorded_weight_attr[name][key] = attr
|
model.recorded_weight_attr[name][key] = attr
|
||||||
# mark the metadata and attributes saved so we don't run it again
|
# mark the metadata and attributes saved so we don't run it again
|
||||||
|
model._model_config = model_config
|
||||||
model.weight_metadata_and_attr_saved = True
|
model.weight_metadata_and_attr_saved = True
|
||||||
|
|
||||||
|
|
||||||
@ -148,77 +144,132 @@ def _bond_method_to_cls(func, obj):
|
|||||||
return types.MethodType(func, obj)
|
return types.MethodType(func, obj)
|
||||||
|
|
||||||
|
|
||||||
def load_weights_and_online_quantize(
|
def support_quantized_model_reload_from_hp_weights(original_load_weights):
|
||||||
model_loader: DefaultModelLoader, model: nn.Module, model_config: ModelConfig
|
"""Decorator for `load_weights` method for AutoWeightsLoader.load_weights to support
|
||||||
) -> set[str]:
|
reloading high precision (bfloat16/float16/float32) weight for an already quantized
|
||||||
|
model, this involves restoring the weights to a high precision weights and
|
||||||
|
then online quantize the weights
|
||||||
|
"""
|
||||||
# online quantization, right now only enabled for
|
# online quantization, right now only enabled for
|
||||||
# torchao
|
# torchao
|
||||||
# R1, R2, R3, R4 in the Notes
|
# R1, R2, R3, R4, R5 in the Notes
|
||||||
|
|
||||||
# TODO: Add fp8 support
|
def patched_model_load_weights(
|
||||||
assert model_config.quantization == "torchao", (
|
auto_weight_loader, weights: Iterable[tuple[str, torch.Tensor]], *, mapper=None
|
||||||
"online quantization is only enabled for torchao currently"
|
) -> set[str]:
|
||||||
)
|
model = auto_weight_loader.module
|
||||||
# TODO: use create_weights to restore the weights to original state
|
offline_quantization_or_first_run_of_online_quantization = not getattr(
|
||||||
|
model, "weight_metadata_and_attr_saved", False
|
||||||
|
)
|
||||||
|
|
||||||
# Step R1: First restore the quantized weights to original bfloat16
|
# if we don't have `model.weight_metadata_and_attr_saved` defined and
|
||||||
# weights, with original metadata (shape, dtype, device)
|
# set to True, it means that this is either offline quantization case
|
||||||
# and attributes, so that bfloat16 weights can be loaded properly
|
# or the first run of online quantization
|
||||||
existing_param_names = dict(model.named_parameters(remove_duplicate=False)).keys()
|
# see Notes in this file for more details
|
||||||
named_modules = dict(model.named_modules(remove_duplicate=False))
|
if offline_quantization_or_first_run_of_online_quantization:
|
||||||
model_device = None
|
# case 1: offline quantized checkpoint
|
||||||
|
# case 2: Step I1 first run of weight loading with
|
||||||
|
# online quantization
|
||||||
|
return original_load_weights(auto_weight_loader, weights, mapper=mapper)
|
||||||
|
|
||||||
# Step R2: recover the parameter to the state before first loading
|
model_config = model._model_config
|
||||||
for name, d in model.original_weights_rebuild_keys.items():
|
|
||||||
_shape = d["shape"]
|
# TODO: Add fp8 support
|
||||||
_dtype = d["dtype"]
|
assert model_config.quantization == "torchao", (
|
||||||
_device = d["device"]
|
"online quantization is only enabled for torchao currently"
|
||||||
|
)
|
||||||
|
# TODO: use create_weights to restore the weights to original state
|
||||||
|
|
||||||
|
# Step R1: First restore the quantized weights to original bfloat16
|
||||||
|
# weights, with original metadata (shape, dtype, device)
|
||||||
|
# and attributes, so that bfloat16 weights can be loaded properly
|
||||||
|
# TODO: maybe set remove_duplicate to True?
|
||||||
|
original_quantized_weight_dict = dict(
|
||||||
|
model.named_parameters(remove_duplicate=False)
|
||||||
|
)
|
||||||
|
named_modules = dict(model.named_modules(remove_duplicate=False))
|
||||||
|
model_device = None
|
||||||
|
|
||||||
|
for name, d in model.original_weights_rebuild_keys.items():
|
||||||
|
_shape = d["shape"]
|
||||||
|
_dtype = d["dtype"]
|
||||||
|
_device = d["device"]
|
||||||
|
if model_device is not None:
|
||||||
|
assert model_device == _device, (
|
||||||
|
"Expecting all weights "
|
||||||
|
"to be in the same device for now, got both: "
|
||||||
|
f"{model_device} and {_device}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model_device = _device
|
||||||
|
|
||||||
|
if name in original_quantized_weight_dict:
|
||||||
|
module_name, weight_name = name.rsplit(".", 1)
|
||||||
|
module = named_modules[module_name]
|
||||||
|
setattr(
|
||||||
|
module,
|
||||||
|
weight_name,
|
||||||
|
torch.nn.Parameter(
|
||||||
|
torch.empty(_shape, dtype=_dtype, device=_device),
|
||||||
|
requires_grad=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step R2: recover the weight attributes to the state before first loading
|
||||||
|
# recorded_weight_attr is
|
||||||
|
# {"weight_name": {"weight_attr_key": attr}}
|
||||||
|
# e.g.
|
||||||
|
# {
|
||||||
|
# {
|
||||||
|
# "layer.0.weight": {
|
||||||
|
# "weight_loader": weight_loader_function_object,
|
||||||
|
# "input_dim": 0, ...
|
||||||
|
# },
|
||||||
|
# "layer.1.weight": ...,
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
for full_weight_name, weight_attr_dict in model.recorded_weight_attr.items():
|
||||||
|
for attr_name, attr in weight_attr_dict.items():
|
||||||
|
module_name, weight_name = full_weight_name.rsplit(".", 1)
|
||||||
|
module = named_modules[module_name]
|
||||||
|
weight = getattr(module, weight_name)
|
||||||
|
if not hasattr(weight, attr_name):
|
||||||
|
setattr(weight, attr_name, _bond_method_to_cls(attr, weight))
|
||||||
|
|
||||||
|
# Step R3: reload bfloat16 / high precision weights
|
||||||
|
updated_params = original_load_weights(
|
||||||
|
auto_weight_loader, weights, mapper=mapper
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step R4: online quantize the weights
|
||||||
|
# manually process weights after loading
|
||||||
|
model.process_weights_after_loading_already_called = False
|
||||||
if model_device is not None:
|
if model_device is not None:
|
||||||
assert model_device == _device, (
|
process_weights_after_loading(model, model_config, model_device)
|
||||||
"Expecting all weights "
|
|
||||||
"to be in the same device for now, got both: "
|
|
||||||
f"{model_device} and {_device}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
model_device = _device
|
logger.warning_once(
|
||||||
|
"model_device is None, skip calling process_weights_after_loading"
|
||||||
if name in existing_param_names:
|
|
||||||
module_name, weight_name = name.rsplit(".", 1)
|
|
||||||
module = named_modules[module_name]
|
|
||||||
setattr(
|
|
||||||
module,
|
|
||||||
weight_name,
|
|
||||||
torch.nn.Parameter(torch.empty(_shape, dtype=_dtype, device=_device)),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# recorded_weight_attr is
|
# Step R5 (workaround for cudagraph): restore the original quantized weights
|
||||||
# {"weight_name": {"weight_attr_key": attr}}
|
# and do a copy_ of the currents weights to the original weights
|
||||||
# e.g.
|
updated_quantized_weights = dict(model.named_parameters(remove_duplicate=False))
|
||||||
# {
|
for name in model.original_weights_rebuild_keys:
|
||||||
# {
|
if name in original_quantized_weight_dict:
|
||||||
# "layer.0.weight": {
|
original_quantized_weight = original_quantized_weight_dict[name]
|
||||||
# "weight_loader": weight_loader_function_object,
|
updated_quantized_weight = updated_quantized_weights[name]
|
||||||
# "input_dim": 0, ...
|
|
||||||
# },
|
|
||||||
# "layer.1.weight": ...,
|
|
||||||
# }
|
|
||||||
# }
|
|
||||||
for full_weight_name, weight_attr_dict in model.recorded_weight_attr.items():
|
|
||||||
for attr_name, attr in weight_attr_dict.items():
|
|
||||||
module_name, weight_name = full_weight_name.rsplit(".", 1)
|
|
||||||
module = named_modules[module_name]
|
|
||||||
weight = getattr(module, weight_name)
|
|
||||||
if not hasattr(weight, attr_name):
|
|
||||||
setattr(weight, attr_name, _bond_method_to_cls(attr, weight))
|
|
||||||
|
|
||||||
# Step I1: reload bfloat16 / high precision weights
|
module_name, weight_name = name.rsplit(".", 1)
|
||||||
loaded_weights = model.load_weights(
|
module = named_modules[module_name]
|
||||||
model_loader.get_all_weights(model_config, model)
|
setattr(module, weight_name, original_quantized_weight)
|
||||||
)
|
with torch.no_grad():
|
||||||
|
original_quantized_weight.copy_(updated_quantized_weight)
|
||||||
|
|
||||||
# Step I2: online quantize the weights
|
del original_quantized_weight_dict
|
||||||
# manually process weights after loading
|
del named_modules
|
||||||
model.process_weights_after_loading_already_called = False
|
del updated_quantized_weight
|
||||||
process_weights_after_loading(model, model_config, model_device)
|
|
||||||
model.process_weights_after_loading_already_called = True
|
model.process_weights_after_loading_already_called = True
|
||||||
return loaded_weights
|
return updated_params
|
||||||
|
|
||||||
|
return patched_model_load_weights
|
||||||
|
|||||||
@ -88,6 +88,14 @@ def initialize_model(
|
|||||||
def process_weights_after_loading(
|
def process_weights_after_loading(
|
||||||
model: nn.Module, model_config: ModelConfig, target_device: torch.device
|
model: nn.Module, model_config: ModelConfig, target_device: torch.device
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if getattr(model, "process_weights_after_loading_already_called", False):
|
||||||
|
# In case `process_weights_after_loading` is called multiple times
|
||||||
|
# we'll skip it at later times
|
||||||
|
logger.debug_once(
|
||||||
|
"process_weights_after_loading already called for model %s", model
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
# to avoid circular dependency
|
# to avoid circular dependency
|
||||||
from vllm.model_executor.model_loader.online_quantization import (
|
from vllm.model_executor.model_loader.online_quantization import (
|
||||||
maybe_save_metadata_and_attributes_for_weight_reloading,
|
maybe_save_metadata_and_attributes_for_weight_reloading,
|
||||||
|
|||||||
@ -21,6 +21,9 @@ from vllm.logger import init_logger
|
|||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.model_loader.online_quantization import (
|
||||||
|
support_quantized_model_reload_from_hp_weights,
|
||||||
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.interfaces import supports_any_eagle
|
from vllm.model_executor.models.interfaces import supports_any_eagle
|
||||||
from vllm.multimodal import NestedTensors
|
from vllm.multimodal import NestedTensors
|
||||||
@ -316,6 +319,7 @@ class AutoWeightsLoader:
|
|||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
@support_quantized_model_reload_from_hp_weights
|
||||||
def load_weights(
|
def load_weights(
|
||||||
self,
|
self,
|
||||||
weights: Iterable[tuple[str, torch.Tensor]],
|
weights: Iterable[tuple[str, torch.Tensor]],
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user