mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-17 04:25:01 +08:00
Support using Int4PreshuffledTensor after loading (#26066)
Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
This commit is contained in:
parent
2ec401bc39
commit
03c4c4aa9d
@ -99,7 +99,7 @@ def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner):
|
||||
|
||||
|
||||
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
|
||||
def test_on_the_fly_quant_config_dict_json(vllm_runner):
|
||||
def test_online_quant_config_dict_json(vllm_runner):
|
||||
"""Testing on the fly quantization, load_weights integration point,
|
||||
with config dict serialized to json string
|
||||
"""
|
||||
@ -133,7 +133,7 @@ def test_on_the_fly_quant_config_dict_json(vllm_runner):
|
||||
|
||||
|
||||
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
|
||||
def test_on_the_fly_quant_config_file(vllm_runner):
|
||||
def test_online_quant_config_file(vllm_runner):
|
||||
"""Testing on the fly quantization, load_weights integration point,
|
||||
with config file
|
||||
"""
|
||||
@ -255,5 +255,147 @@ def test_opt_125m_module_fqn_to_config_regex_model(vllm_runner):
|
||||
assert output
|
||||
|
||||
|
||||
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
|
||||
@pytest.mark.skip(
|
||||
reason="since torchao nightly is only compatible with torch nightly"
|
||||
"currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
|
||||
"torchao tests that requires newer versions (0.14.0.dev+) for now"
|
||||
)
|
||||
def test_opt_125m_int4wo_model_running_preshuffled_kernel(vllm_runner, monkeypatch):
|
||||
"""We load a model with Int4Tensor (plain format) linear weights
|
||||
and verify that the weight is updated to Int4PreshuffledTensor
|
||||
after loading in vllm
|
||||
"""
|
||||
from torchao.quantization import Int4PreshuffledTensor
|
||||
from torchao.utils import _is_fbgemm_gpu_genai_available, is_sm_at_least_90
|
||||
|
||||
torch._dynamo.reset()
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
model_name = "torchao-testing/opt-125m-Int4WeightOnlyConfig-v2-0.14.0.dev"
|
||||
# Note: using enforce_eager=True because the `bf16i4bf16_shuffled` doesn't
|
||||
# have meta kernel implemented yet, can remove this flag after that is implemented
|
||||
with vllm_runner(
|
||||
model_name=model_name,
|
||||
quantization="torchao",
|
||||
dtype="bfloat16",
|
||||
pt_load_map_location="cuda:0",
|
||||
enforce_eager=True,
|
||||
) as llm:
|
||||
|
||||
def has_int4_preshuffled_tensor_weight(model):
|
||||
return isinstance(
|
||||
model.model.decoder.layers[0].self_attn.qkv_proj.weight,
|
||||
Int4PreshuffledTensor,
|
||||
)
|
||||
|
||||
def get_weight_attrs(model):
|
||||
weight = model.model.decoder.layers[0].self_attn.qkv_proj.weight
|
||||
return [
|
||||
weight.requires_grad,
|
||||
weight.input_dim,
|
||||
weight.output_dim,
|
||||
hasattr(weight, "weight_loader"),
|
||||
]
|
||||
|
||||
llm_engine = llm.get_llm().llm_engine
|
||||
has_int4_preshuffled_tensor = any(
|
||||
llm_engine.apply_model(has_int4_preshuffled_tensor_weight)
|
||||
)
|
||||
weight_attrs = llm_engine.apply_model(get_weight_attrs)[0]
|
||||
|
||||
# making sure we are using Int4PreshuffledTensor on H100 GPU, when
|
||||
# fbgemm_gpu_genai
|
||||
# library is installed, otherwise it should be using Int4Tensor
|
||||
if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90():
|
||||
assert has_int4_preshuffled_tensor
|
||||
else:
|
||||
assert not has_int4_preshuffled_tensor
|
||||
|
||||
assert weight_attrs == [False, 1, 0, True]
|
||||
output = llm.generate_greedy(["The capital of France is"], max_tokens=32)
|
||||
|
||||
assert output
|
||||
|
||||
|
||||
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
|
||||
@pytest.mark.skip(
|
||||
reason="since torchao nightly is only compatible with torch nightly"
|
||||
"currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
|
||||
"torchao tests that requires newer versions (0.14.0.dev+) for now"
|
||||
)
|
||||
def test_opt_125m_int4wo_model_running_preshuffled_kernel_online_quant(
|
||||
vllm_runner, monkeypatch
|
||||
):
|
||||
"""We load a bf16 model and online quantize the model to int4, then verify that
|
||||
the weights are updated to Int4PreshuffledTensor after online quantization
|
||||
"""
|
||||
from torchao.quantization import Int4PreshuffledTensor
|
||||
from torchao.utils import _is_fbgemm_gpu_genai_available, is_sm_at_least_90
|
||||
|
||||
torch._dynamo.reset()
|
||||
model_name = "facebook/opt-125m"
|
||||
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
import json
|
||||
|
||||
from torchao.core.config import config_to_dict
|
||||
from torchao.quantization import Int4WeightOnlyConfig
|
||||
|
||||
torchao_quant_config = Int4WeightOnlyConfig(
|
||||
group_size=128, int4_packing_format="plain"
|
||||
)
|
||||
hf_overrides = {
|
||||
"quantization_config_dict_json": json.dumps(
|
||||
config_to_dict(torchao_quant_config)
|
||||
)
|
||||
}
|
||||
|
||||
# Note: using enforce_eager=True because the `bf16i4bf16_shuffled` doesn't
|
||||
# have meta kernel implemented yet, can remove this flag after that is implemented
|
||||
with vllm_runner(
|
||||
model_name=model_name,
|
||||
quantization="torchao",
|
||||
dtype="bfloat16",
|
||||
pt_load_map_location="cuda:0",
|
||||
hf_overrides=hf_overrides,
|
||||
enforce_eager=True,
|
||||
) as llm:
|
||||
|
||||
def has_int4_preshuffled_tensor_weight(model):
|
||||
return isinstance(
|
||||
model.model.decoder.layers[0].self_attn.qkv_proj.weight,
|
||||
Int4PreshuffledTensor,
|
||||
)
|
||||
|
||||
def get_weight_attrs(model):
|
||||
weight = model.model.decoder.layers[0].self_attn.qkv_proj.weight
|
||||
return [
|
||||
weight.requires_grad,
|
||||
weight.input_dim,
|
||||
weight.output_dim,
|
||||
hasattr(weight, "weight_loader"),
|
||||
]
|
||||
|
||||
llm_engine = llm.get_llm().llm_engine
|
||||
has_int4_preshuffled_tensor = any(
|
||||
llm_engine.apply_model(has_int4_preshuffled_tensor_weight)
|
||||
)
|
||||
weight_attrs = llm_engine.apply_model(get_weight_attrs)[0]
|
||||
|
||||
# making sure we are using Int4PreshuffledTensor on H100 GPU, when
|
||||
# fbgemm_gpu_genai
|
||||
# library is installed, otherwise it should be using Int4Tensor
|
||||
if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90():
|
||||
assert has_int4_preshuffled_tensor
|
||||
else:
|
||||
assert not has_int4_preshuffled_tensor
|
||||
|
||||
assert weight_attrs == [False, 1, 0, True]
|
||||
output = llm.generate_greedy(["The capital of France is"], max_tokens=32)
|
||||
|
||||
assert output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import importlib
|
||||
import json
|
||||
import types
|
||||
from importlib.util import find_spec
|
||||
from typing import Any, Optional
|
||||
|
||||
@ -27,6 +28,39 @@ from vllm.model_executor.utils import set_weight_attrs
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _bond_method_to_cls(func, obj):
|
||||
if hasattr(func, "__self__") or not callable(func):
|
||||
# If the function is already bound to an instance, return it as is
|
||||
return func
|
||||
else:
|
||||
return types.MethodType(func, obj)
|
||||
|
||||
|
||||
def _get_weight_attrs(param):
|
||||
# record attributes attached to the weight, so we can
|
||||
# recover later
|
||||
recorded_weight_attr = {}
|
||||
for key in param.__dict__:
|
||||
if hasattr(param, key):
|
||||
attr = getattr(param, key)
|
||||
if not callable(attr):
|
||||
recorded_weight_attr[key] = attr
|
||||
elif hasattr(attr, "__self__") and param is attr.__self__:
|
||||
# if attr is a bonded method for an instance, and
|
||||
# attr.__self__ points to the instance (param)
|
||||
# we'll record the underlying function object
|
||||
recorded_weight_attr[key] = attr.__func__
|
||||
else:
|
||||
recorded_weight_attr[key] = attr
|
||||
return recorded_weight_attr
|
||||
|
||||
|
||||
def _restore_weight_attrs(param, recorded_weight_attr):
|
||||
for attr_name, attr in recorded_weight_attr.items():
|
||||
if not hasattr(param, attr_name):
|
||||
setattr(param, attr_name, _bond_method_to_cls(attr, param))
|
||||
|
||||
|
||||
def torchao_version_at_least(torchao_version: str) -> bool:
|
||||
if find_spec("torchao"):
|
||||
try:
|
||||
@ -57,6 +91,14 @@ def should_skip(prefix: str, skip_modules: list[str]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
if torchao_version_at_least("0.15.0"):
|
||||
from torchao.prototype.tensor_conversion.api import (
|
||||
convert_to_packed_tensor_based_on_current_hardware,
|
||||
)
|
||||
else:
|
||||
convert_to_packed_tensor_based_on_current_hardware = lambda t: t
|
||||
|
||||
|
||||
class TorchAOConfig(QuantizationConfig):
|
||||
"""Config class for torchao."""
|
||||
|
||||
@ -307,12 +349,32 @@ class TorchAOLinearMethod(LinearMethodBase):
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if self.quant_config.is_checkpoint_torchao_serialized:
|
||||
if not hasattr(layer, "weight"):
|
||||
return
|
||||
|
||||
# quantize the weight on the fly if the checkpoint is not already
|
||||
# record attributes attached to the weight, so we can
|
||||
# recover later
|
||||
recorded_weight_attr = _get_weight_attrs(layer.weight)
|
||||
|
||||
layer.weight = Parameter(
|
||||
convert_to_packed_tensor_based_on_current_hardware(layer.weight),
|
||||
requires_grad=layer.weight.requires_grad,
|
||||
)
|
||||
|
||||
_restore_weight_attrs(layer.weight, recorded_weight_attr)
|
||||
return
|
||||
|
||||
# online quantize the weight if the checkpoint is not already
|
||||
# quantized by torchao
|
||||
recorded_weight_attr = _get_weight_attrs(layer.weight)
|
||||
|
||||
weight = torchao_quantize_param_data(
|
||||
layer.weight, self.quant_config.torchao_config
|
||||
)
|
||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||
weight = torch.nn.Parameter(
|
||||
convert_to_packed_tensor_based_on_current_hardware(weight),
|
||||
weight.requires_grad,
|
||||
)
|
||||
|
||||
_restore_weight_attrs(weight, recorded_weight_attr)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user