Support using Int4PreshuffledTensor after loading (#26066)

Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
This commit is contained in:
Jerry Zhang 2025-11-04 03:00:57 -08:00 committed by GitHub
parent 2ec401bc39
commit 03c4c4aa9d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 208 additions and 4 deletions

View File

@ -99,7 +99,7 @@ def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner):
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_on_the_fly_quant_config_dict_json(vllm_runner): def test_online_quant_config_dict_json(vllm_runner):
"""Testing on the fly quantization, load_weights integration point, """Testing on the fly quantization, load_weights integration point,
with config dict serialized to json string with config dict serialized to json string
""" """
@ -133,7 +133,7 @@ def test_on_the_fly_quant_config_dict_json(vllm_runner):
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_on_the_fly_quant_config_file(vllm_runner): def test_online_quant_config_file(vllm_runner):
"""Testing on the fly quantization, load_weights integration point, """Testing on the fly quantization, load_weights integration point,
with config file with config file
""" """
@ -252,6 +252,148 @@ def test_opt_125m_module_fqn_to_config_regex_model(vllm_runner):
) as llm: ) as llm:
output = llm.generate_greedy(["The capital of France is"], max_tokens=4) output = llm.generate_greedy(["The capital of France is"], max_tokens=4)
assert output
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@pytest.mark.skip(
reason="since torchao nightly is only compatible with torch nightly"
"currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
"torchao tests that requires newer versions (0.14.0.dev+) for now"
)
def test_opt_125m_int4wo_model_running_preshuffled_kernel(vllm_runner, monkeypatch):
"""We load a model with Int4Tensor (plain format) linear weights
and verify that the weight is updated to Int4PreshuffledTensor
after loading in vllm
"""
from torchao.quantization import Int4PreshuffledTensor
from torchao.utils import _is_fbgemm_gpu_genai_available, is_sm_at_least_90
torch._dynamo.reset()
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
model_name = "torchao-testing/opt-125m-Int4WeightOnlyConfig-v2-0.14.0.dev"
# Note: using enforce_eager=True because the `bf16i4bf16_shuffled` doesn't
# have meta kernel implemented yet, can remove this flag after that is implemented
with vllm_runner(
model_name=model_name,
quantization="torchao",
dtype="bfloat16",
pt_load_map_location="cuda:0",
enforce_eager=True,
) as llm:
def has_int4_preshuffled_tensor_weight(model):
return isinstance(
model.model.decoder.layers[0].self_attn.qkv_proj.weight,
Int4PreshuffledTensor,
)
def get_weight_attrs(model):
weight = model.model.decoder.layers[0].self_attn.qkv_proj.weight
return [
weight.requires_grad,
weight.input_dim,
weight.output_dim,
hasattr(weight, "weight_loader"),
]
llm_engine = llm.get_llm().llm_engine
has_int4_preshuffled_tensor = any(
llm_engine.apply_model(has_int4_preshuffled_tensor_weight)
)
weight_attrs = llm_engine.apply_model(get_weight_attrs)[0]
# making sure we are using Int4PreshuffledTensor on H100 GPU, when
# fbgemm_gpu_genai
# library is installed, otherwise it should be using Int4Tensor
if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90():
assert has_int4_preshuffled_tensor
else:
assert not has_int4_preshuffled_tensor
assert weight_attrs == [False, 1, 0, True]
output = llm.generate_greedy(["The capital of France is"], max_tokens=32)
assert output
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@pytest.mark.skip(
reason="since torchao nightly is only compatible with torch nightly"
"currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
"torchao tests that requires newer versions (0.14.0.dev+) for now"
)
def test_opt_125m_int4wo_model_running_preshuffled_kernel_online_quant(
vllm_runner, monkeypatch
):
"""We load a bf16 model and online quantize the model to int4, then verify that
the weights are updated to Int4PreshuffledTensor after online quantization
"""
from torchao.quantization import Int4PreshuffledTensor
from torchao.utils import _is_fbgemm_gpu_genai_available, is_sm_at_least_90
torch._dynamo.reset()
model_name = "facebook/opt-125m"
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
import json
from torchao.core.config import config_to_dict
from torchao.quantization import Int4WeightOnlyConfig
torchao_quant_config = Int4WeightOnlyConfig(
group_size=128, int4_packing_format="plain"
)
hf_overrides = {
"quantization_config_dict_json": json.dumps(
config_to_dict(torchao_quant_config)
)
}
# Note: using enforce_eager=True because the `bf16i4bf16_shuffled` doesn't
# have meta kernel implemented yet, can remove this flag after that is implemented
with vllm_runner(
model_name=model_name,
quantization="torchao",
dtype="bfloat16",
pt_load_map_location="cuda:0",
hf_overrides=hf_overrides,
enforce_eager=True,
) as llm:
def has_int4_preshuffled_tensor_weight(model):
return isinstance(
model.model.decoder.layers[0].self_attn.qkv_proj.weight,
Int4PreshuffledTensor,
)
def get_weight_attrs(model):
weight = model.model.decoder.layers[0].self_attn.qkv_proj.weight
return [
weight.requires_grad,
weight.input_dim,
weight.output_dim,
hasattr(weight, "weight_loader"),
]
llm_engine = llm.get_llm().llm_engine
has_int4_preshuffled_tensor = any(
llm_engine.apply_model(has_int4_preshuffled_tensor_weight)
)
weight_attrs = llm_engine.apply_model(get_weight_attrs)[0]
# making sure we are using Int4PreshuffledTensor on H100 GPU, when
# fbgemm_gpu_genai
# library is installed, otherwise it should be using Int4Tensor
if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90():
assert has_int4_preshuffled_tensor
else:
assert not has_int4_preshuffled_tensor
assert weight_attrs == [False, 1, 0, True]
output = llm.generate_greedy(["The capital of France is"], max_tokens=32)
assert output assert output

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib import importlib
import json import json
import types
from importlib.util import find_spec from importlib.util import find_spec
from typing import Any, Optional from typing import Any, Optional
@ -27,6 +28,39 @@ from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__) logger = init_logger(__name__)
def _bond_method_to_cls(func, obj):
if hasattr(func, "__self__") or not callable(func):
# If the function is already bound to an instance, return it as is
return func
else:
return types.MethodType(func, obj)
def _get_weight_attrs(param):
# record attributes attached to the weight, so we can
# recover later
recorded_weight_attr = {}
for key in param.__dict__:
if hasattr(param, key):
attr = getattr(param, key)
if not callable(attr):
recorded_weight_attr[key] = attr
elif hasattr(attr, "__self__") and param is attr.__self__:
# if attr is a bonded method for an instance, and
# attr.__self__ points to the instance (param)
# we'll record the underlying function object
recorded_weight_attr[key] = attr.__func__
else:
recorded_weight_attr[key] = attr
return recorded_weight_attr
def _restore_weight_attrs(param, recorded_weight_attr):
for attr_name, attr in recorded_weight_attr.items():
if not hasattr(param, attr_name):
setattr(param, attr_name, _bond_method_to_cls(attr, param))
def torchao_version_at_least(torchao_version: str) -> bool: def torchao_version_at_least(torchao_version: str) -> bool:
if find_spec("torchao"): if find_spec("torchao"):
try: try:
@ -57,6 +91,14 @@ def should_skip(prefix: str, skip_modules: list[str]) -> bool:
return False return False
if torchao_version_at_least("0.15.0"):
from torchao.prototype.tensor_conversion.api import (
convert_to_packed_tensor_based_on_current_hardware,
)
else:
convert_to_packed_tensor_based_on_current_hardware = lambda t: t
class TorchAOConfig(QuantizationConfig): class TorchAOConfig(QuantizationConfig):
"""Config class for torchao.""" """Config class for torchao."""
@ -307,12 +349,32 @@ class TorchAOLinearMethod(LinearMethodBase):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if self.quant_config.is_checkpoint_torchao_serialized: if self.quant_config.is_checkpoint_torchao_serialized:
if not hasattr(layer, "weight"):
return
# record attributes attached to the weight, so we can
# recover later
recorded_weight_attr = _get_weight_attrs(layer.weight)
layer.weight = Parameter(
convert_to_packed_tensor_based_on_current_hardware(layer.weight),
requires_grad=layer.weight.requires_grad,
)
_restore_weight_attrs(layer.weight, recorded_weight_attr)
return return
# quantize the weight on the fly if the checkpoint is not already # online quantize the weight if the checkpoint is not already
# quantized by torchao # quantized by torchao
recorded_weight_attr = _get_weight_attrs(layer.weight)
weight = torchao_quantize_param_data( weight = torchao_quantize_param_data(
layer.weight, self.quant_config.torchao_config layer.weight, self.quant_config.torchao_config
) )
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) weight = torch.nn.Parameter(
convert_to_packed_tensor_based_on_current_hardware(weight),
weight.requires_grad,
)
_restore_weight_attrs(weight, recorded_weight_attr)
layer.register_parameter("weight", weight) layer.register_parameter("weight", weight)