Move online quantization to model.load_weights (#26327)

Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
Jerry Zhang 2025-11-18 16:52:41 -08:00 committed by GitHub
parent 1395461f5f
commit da94c7c0eb
6 changed files with 309 additions and 108 deletions

View File

@@ -62,7 +62,7 @@ ray.init()
# Create a placement group that reserves GPUs 1 and 2 for the vLLM inference engine.
# Learn more about Ray placement groups:
# https://docs.ray.io/en/latest/placement-groups.html
# https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(

View File

@@ -0,0 +1,162 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray.

The script separates training and inference workloads onto distinct GPUs
so that Ray can manage process placement and inter-process communication.
A Hugging Face Transformer model occupies GPU 0 for training, whereas a
tensor-parallel vLLM inference engine occupies GPUs 1 and 2.

The example performs the following steps:

* Load the training model on GPU 0.
* Split the inference model across GPUs 1 and 2 using vLLM's tensor parallelism
  and Ray placement groups.
* Generate text from a list of prompts using the inference engine.
* Update the weights of the training model and broadcast the updated weights
  to the inference engine through a Ray collective RPC group. Note that
  for demonstration purposes we simply zero out the weights.

For a production-ready implementation that supports multiple training and
inference replicas, see the OpenRLHF framework:
https://github.com/OpenRLHF/OpenRLHF

This example assumes a single-node cluster with three GPUs, but Ray also
supports multi-node clusters. vLLM expects the GPUs to be used exclusively
for vLLM workloads; residual GPU activity interferes with vLLM memory
profiling and causes unexpected behavior.
"""
import json
import os
import ray
import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from rlhf_utils import stateless_init_process_group
from torchao.core.config import config_to_dict
from torchao.quantization import (
Float8DynamicActivationFloat8WeightConfig,
PerRow,
)
from transformers import AutoModelForCausalLM
from vllm import LLM, SamplingParams
from vllm.utils.network_utils import get_ip, get_open_port
class MyLLM(LLM):
"""Configure the vLLM worker for Ray placement group execution."""
def __init__(self, *args, **kwargs):
# Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray
# so that vLLM can manage its own device placement within the worker.
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
super().__init__(*args, **kwargs)
# Load the OPT-125M model onto GPU 0 for the training workload.
train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
train_model.to("cuda:0")
# Initialize Ray and set the visible devices. The vLLM engine will
# be placed on GPUs 1 and 2.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
ray.init()
# Create a placement group that reserves GPUs 1 and 2 for the vLLM inference engine.
# Learn more about Ray placement groups:
# https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
placement_group=pg_inference,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=0,
)
# Launch the vLLM inference engine. The `enforce_eager` flag reduces
# start-up latency.
# Generate the torchao quantization config for the RL rollout.
# See https://github.com/vllm-project/vllm/pull/23014 for how to use
# serialized config files instead of passing a JSON string around.
config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
json_str = json.dumps(config_to_dict(config))
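# For illustration only (not taken from this commit): the engine side can
# rebuild the config object from this JSON string, e.g. with
# torchao.core.config.config_from_dict(json.loads(json_str)), assuming a
# torchao version that ships the config serialization helpers.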
llm = ray.remote(
num_cpus=0,
num_gpus=0,
scheduling_strategy=scheduling_inference,
)(MyLLM).remote(
model="facebook/opt-125m",
hf_overrides={"quantization_config_dict_json": json_str},
enforce_eager=True,
worker_extension_cls="rlhf_utils.WorkerExtension",
tensor_parallel_size=2,
distributed_executor_backend="ray",
)
# Generate text from the prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
outputs = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
# Set up the communication channel between the training process and the
# inference engine.
master_address = get_ip()
master_port = get_open_port()
handle = llm.collective_rpc.remote(
"init_weight_update_group", args=(master_address, master_port, 1, 3)
)
model_update_group = stateless_init_process_group(
master_address, master_port, 0, 3, torch.device("cuda:0")
)
ray.get(handle)
# Simulate a training step by zeroing out all model weights.
# In a real RLHF training loop the weights would be updated using gradients
# from an RL objective such as PPO, computed against a reward model.
for name, p in train_model.named_parameters():
p.data.zero_()
# Synchronize the updated weights to the inference engine.
for name, p in train_model.named_parameters():
dtype_name = str(p.dtype).split(".")[-1]
handle = llm.collective_rpc.remote(
"update_weight", args=(name, dtype_name, p.shape)
)
model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
ray.get(handle)
# Verify that the inference weights have been updated.
assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))
# Generate text with the updated model. The output is expected to be nonsense
# because the weights are zero.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs_updated:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)

View File

@@ -22,6 +22,7 @@ from vllm.model_executor.model_loader.weight_utils import (
fastsafetensors_weights_iterator,
filter_duplicate_safetensors_files,
filter_files_not_needed_for_inference,
get_quant_config,
maybe_download_from_modelscope,
multi_thread_pt_weights_iterator,
multi_thread_safetensors_weights_iterator,
@@ -273,42 +274,17 @@ class DefaultModelLoader(BaseModelLoader):
)
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
if model_config.quantization == "torchao" and torchao_version_at_least(
"0.14.0"
):
self.load_config.safetensors_load_strategy = "torchao"
if model_config.quantization == "torchao":
quant_config = get_quant_config(model_config, self.load_config)
if (
hasattr(quant_config, "is_checkpoint_torchao_serialized")
and quant_config.is_checkpoint_torchao_serialized
and torchao_version_at_least("0.14.0")
):
self.load_config.safetensors_load_strategy = "torchao"
weights_to_load = {name for name, _ in model.named_parameters()}
# if we don't have `model.weight_metadata_and_attr_saved` defined and
# set to True, it means that this is either offline quantization case
# or the first run of online quantization
# see online_quantization.py for detailed notes
offline_quantization_or_first_run_of_online_quantization = not getattr(
model, "weight_metadata_and_attr_saved", False
)
if model_config.quantization is None:
# model is not quantized
loaded_weights = model.load_weights(
self.get_all_weights(model_config, model)
)
elif offline_quantization_or_first_run_of_online_quantization:
# case 1: offline quantized checkpoint
# case 2: Step I1 first run of weight loading with
# online quantization
# see online_quantization.py for detailed notes
loaded_weights = model.load_weights(
self.get_all_weights(model_config, model)
)
else:
# to avoid circular dependency
from vllm.model_executor.model_loader.online_quantization import (
load_weights_and_online_quantize,
)
# subsequent runs of weight loading with online
# quantization
loaded_weights = load_weights_and_online_quantize(self, model, model_config)
loaded_weights = model.load_weights(self.get_all_weights(model_config, model))
self.counter_after_loading_weights = time.perf_counter()
logger.info_once(
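After this change the loader no longer branches on online quantization: the
first load and any later high-precision reload both go through
`model.load_weights`, and the reload handling lives in the decorator added in
the next file. As a reading aid, here is a rough sketch of the resulting path;
names come from the diff above, but the body is condensed for illustration and
is not the literal vLLM code.

# Rough sketch of DefaultModelLoader.load_weights after this commit.
def load_weights(self, model, model_config):
    if model_config.quantization == "torchao":
        quant_config = get_quant_config(model_config, self.load_config)
        if (
            getattr(quant_config, "is_checkpoint_torchao_serialized", False)
            and torchao_version_at_least("0.14.0")
        ):
            # Only torchao-serialized checkpoints need the dedicated
            # safetensors loading strategy.
            self.load_config.safetensors_load_strategy = "torchao"
    # Both the first load and subsequent reloads take the same path; the
    # decorated model.load_weights decides whether to restore and re-quantize.
    return model.load_weights(self.get_all_weights(model_config, model))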

View File

@@ -2,13 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import types
from collections.abc import Iterable
import torch
from torch import nn
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.utils import process_weights_after_loading
logger = init_logger(__name__)
@@ -56,6 +56,9 @@ logger = init_logger(__name__)
# R4. quantize weights (by calling process_weights_after_loading),
# also set `process_weights_after_loading_already_called` to
# True to stop it from running again
# R5. (workaround for cudagraph) we restore the weight params to the original
# quantized weight params and use original_weight_param.copy_(updated_weight_param)
# so that the weight update works correctly with cudagraph
# process_weights_after_loading (if called):
# this will be skipped since it has already run in
# load_weights
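The R5 step exists because CUDA graphs capture the storage of the quantized
weight tensors; swapping in a new Parameter object would leave the captured
graph pointing at stale memory. A minimal, self-contained sketch of the
in-place update pattern (toy tensors, for illustration only):

import torch

# The parameter captured by a CUDA graph must keep its storage, so updated
# values are copied into it in place rather than replacing the object.
captured = torch.nn.Parameter(torch.zeros(4), requires_grad=False)
updated = torch.nn.Parameter(torch.ones(4), requires_grad=False)
with torch.no_grad():
    captured.copy_(updated)  # same tensor object and address, new values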
@@ -69,14 +72,6 @@ def maybe_save_metadata_and_attributes_for_weight_reloading(
if model_config.quantization != "torchao":
return
if getattr(model, "process_weights_after_loading_already_called", False):
# In case `process_weights_after_loading` is called multiple times
# we'll skip it at later times
logger.warning(
"process_weights_after_loading already called for model %s", model
)
return
from vllm.model_executor.model_loader.weight_utils import get_quant_config
quant_config = get_quant_config(model_config, None)
@@ -137,6 +132,7 @@ def maybe_save_metadata_and_attributes_for_weight_reloading(
else:
model.recorded_weight_attr[name][key] = attr
# mark the metadata and attributes saved so we don't run it again
model._model_config = model_config
model.weight_metadata_and_attr_saved = True
@@ -148,77 +144,132 @@ def _bond_method_to_cls(func, obj):
return types.MethodType(func, obj)
def load_weights_and_online_quantize(
model_loader: DefaultModelLoader, model: nn.Module, model_config: ModelConfig
) -> set[str]:
def support_quantized_model_reload_from_hp_weights(original_load_weights):
"""Decorator for `load_weights` method for AutoWeightsLoader.load_weights to support
reloading high precision (bfloat16/float16/float32) weight for an already quantized
model, this involves restoring the weights to a high precision weights and
then online quantize the weights
"""
# online quantization, right now only enabled for
# torchao
# R1, R2, R3, R4 in the Notes
# R1, R2, R3, R4, R5 in the Notes
# TODO: Add fp8 support
assert model_config.quantization == "torchao", (
"online quantization is only enabled for torchao currently"
)
# TODO: use create_weights to restore the weights to original state
def patched_model_load_weights(
auto_weight_loader, weights: Iterable[tuple[str, torch.Tensor]], *, mapper=None
) -> set[str]:
model = auto_weight_loader.module
offline_quantization_or_first_run_of_online_quantization = not getattr(
model, "weight_metadata_and_attr_saved", False
)
# Step R1: First restore the quantized weights to original bfloat16
# weights, with original metadata (shape, dtype, device)
# and attributes, so that bfloat16 weights can be loaded properly
existing_param_names = dict(model.named_parameters(remove_duplicate=False)).keys()
named_modules = dict(model.named_modules(remove_duplicate=False))
model_device = None
# if we don't have `model.weight_metadata_and_attr_saved` defined and
# set to True, it means that this is either offline quantization case
# or the first run of online quantization
# see Notes in this file for more details
if offline_quantization_or_first_run_of_online_quantization:
# case 1: offline quantized checkpoint
# case 2: Step I1 first run of weight loading with
# online quantization
return original_load_weights(auto_weight_loader, weights, mapper=mapper)
# Step R2: recover the parameter to the state before first loading
for name, d in model.original_weights_rebuild_keys.items():
_shape = d["shape"]
_dtype = d["dtype"]
_device = d["device"]
model_config = model._model_config
# TODO: Add fp8 support
assert model_config.quantization == "torchao", (
"online quantization is only enabled for torchao currently"
)
# TODO: use create_weights to restore the weights to original state
# Step R1: First restore the quantized weights to original bfloat16
# weights, with original metadata (shape, dtype, device)
# and attributes, so that bfloat16 weights can be loaded properly
# TODO: maybe set remove_duplicate to True?
original_quantized_weight_dict = dict(
model.named_parameters(remove_duplicate=False)
)
named_modules = dict(model.named_modules(remove_duplicate=False))
model_device = None
for name, d in model.original_weights_rebuild_keys.items():
_shape = d["shape"]
_dtype = d["dtype"]
_device = d["device"]
if model_device is not None:
assert model_device == _device, (
"Expecting all weights "
"to be in the same device for now, got both: "
f"{model_device} and {_device}"
)
else:
model_device = _device
if name in original_quantized_weight_dict:
module_name, weight_name = name.rsplit(".", 1)
module = named_modules[module_name]
setattr(
module,
weight_name,
torch.nn.Parameter(
torch.empty(_shape, dtype=_dtype, device=_device),
requires_grad=False,
),
)
# Step R2: recover the weight attributes to the state before first loading
# recorded_weight_attr is
# {"weight_name": {"weight_attr_key": attr}}
# e.g.
# {
# {
# "layer.0.weight": {
# "weight_loader": weight_loader_function_object,
# "input_dim": 0, ...
# },
# "layer.1.weight": ...,
# }
# }
for full_weight_name, weight_attr_dict in model.recorded_weight_attr.items():
for attr_name, attr in weight_attr_dict.items():
module_name, weight_name = full_weight_name.rsplit(".", 1)
module = named_modules[module_name]
weight = getattr(module, weight_name)
if not hasattr(weight, attr_name):
setattr(weight, attr_name, _bond_method_to_cls(attr, weight))
# Step R3: reload bfloat16 / high precision weights
updated_params = original_load_weights(
auto_weight_loader, weights, mapper=mapper
)
# Step R4: online quantize the weights
# manually process weights after loading
model.process_weights_after_loading_already_called = False
if model_device is not None:
assert model_device == _device, (
"Expecting all weights "
"to be in the same device for now, got both: "
f"{model_device} and {_device}"
)
process_weights_after_loading(model, model_config, model_device)
else:
model_device = _device
if name in existing_param_names:
module_name, weight_name = name.rsplit(".", 1)
module = named_modules[module_name]
setattr(
module,
weight_name,
torch.nn.Parameter(torch.empty(_shape, dtype=_dtype, device=_device)),
logger.warning_once(
"model_device is None, skip calling process_weights_after_loading"
)
# recorded_weight_attr is
# {"weight_name": {"weight_attr_key": attr}}
# e.g.
# {
# {
# "layer.0.weight": {
# "weight_loader": weight_loader_function_object,
# "input_dim": 0, ...
# },
# "layer.1.weight": ...,
# }
# }
for full_weight_name, weight_attr_dict in model.recorded_weight_attr.items():
for attr_name, attr in weight_attr_dict.items():
module_name, weight_name = full_weight_name.rsplit(".", 1)
module = named_modules[module_name]
weight = getattr(module, weight_name)
if not hasattr(weight, attr_name):
setattr(weight, attr_name, _bond_method_to_cls(attr, weight))
# Step R5 (workaround for cudagraph): restore the original quantized weights
# and copy_ the current weights into the original weights
updated_quantized_weights = dict(model.named_parameters(remove_duplicate=False))
for name in model.original_weights_rebuild_keys:
if name in original_quantized_weight_dict:
original_quantized_weight = original_quantized_weight_dict[name]
updated_quantized_weight = updated_quantized_weights[name]
# Step I1: reload bfloat16 / high precision weights
loaded_weights = model.load_weights(
model_loader.get_all_weights(model_config, model)
)
module_name, weight_name = name.rsplit(".", 1)
module = named_modules[module_name]
setattr(module, weight_name, original_quantized_weight)
with torch.no_grad():
original_quantized_weight.copy_(updated_quantized_weight)
# Step I2: online quantize the weights
# manually process weights after loading
model.process_weights_after_loading_already_called = False
process_weights_after_loading(model, model_config, model_device)
model.process_weights_after_loading_already_called = True
return loaded_weights
del original_quantized_weight_dict
del named_modules
del updated_quantized_weight
model.process_weights_after_loading_already_called = True
return updated_params
return patched_model_load_weights
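Taken together, the decorator turns `AutoWeightsLoader.load_weights` into a
two-phase loader. A condensed sketch of the dispatch it implements, with the
step numbers matching the Notes above; the body is abbreviated for illustration
and is not the literal implementation.

# Condensed sketch of the decorator's dispatch; see the full diff above.
def support_quantized_model_reload_from_hp_weights(original_load_weights):
    def patched(auto_weight_loader, weights, *, mapper=None):
        model = auto_weight_loader.module
        if not getattr(model, "weight_metadata_and_attr_saved", False):
            # Offline-quantized checkpoint, or the first load with online
            # quantization: nothing to restore yet.
            return original_load_weights(auto_weight_loader, weights, mapper=mapper)
        # R1/R2: rebuild high-precision Parameters and their recorded attributes.
        # R3: reload the bf16/fp16/fp32 weights through the original loader.
        # R4: re-run process_weights_after_loading to quantize them online.
        # R5: copy_ the new quantized values into the original Parameters so
        #     CUDA graphs keep pointing at valid storage.
        ...
    return patched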

View File

@@ -88,6 +88,14 @@ def initialize_model(
def process_weights_after_loading(
model: nn.Module, model_config: ModelConfig, target_device: torch.device
) -> None:
if getattr(model, "process_weights_after_loading_already_called", False):
# If `process_weights_after_loading` is called multiple times,
# skip the later calls.
logger.debug_once(
"process_weights_after_loading already called for model %s", model
)
return
# to avoid circular dependency
from vllm.model_executor.model_loader.online_quantization import (
maybe_save_metadata_and_attributes_for_weight_reloading,

View File

@@ -21,6 +21,9 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
)
from vllm.model_executor.model_loader.online_quantization import (
support_quantized_model_reload_from_hp_weights,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import supports_any_eagle
from vllm.multimodal import NestedTensors
@@ -316,6 +319,7 @@ class AutoWeightsLoader:
)
raise ValueError(msg)
@support_quantized_model_reload_from_hp_weights
def load_weights(
self,
weights: Iterable[tuple[str, torch.Tensor]],