mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-21 05:45:27 +08:00
[TPU] Add support for online w8a8 quantization (#22425)
Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
This commit is contained in:
parent
10a02535d4
commit
9a0c5ded5a
@ -139,6 +139,8 @@ run_and_track_test 5 "test_spmd_model_weight_loading.py" \
|
|||||||
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
|
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
|
||||||
run_and_track_test 6 "test_kv_cache_update_kernel.py" \
|
run_and_track_test 6 "test_kv_cache_update_kernel.py" \
|
||||||
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py"
|
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py"
|
||||||
|
run_and_track_test 7 "test_tpu_int8.py" \
|
||||||
|
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_int8.py"
|
||||||
|
|
||||||
# After all tests have been attempted, exit with the overall status.
|
# After all tests have been attempted, exit with the overall status.
|
||||||
if [ "$overall_script_exit_code" -ne 0 ]; then
|
if [ "$overall_script_exit_code" -ne 0 ]; then
|
||||||
|
|||||||
73
tests/v1/tpu/test_tpu_int8.py
Normal file
73
tests/v1/tpu/test_tpu_int8.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Tests whether TPU Int8 computation is enabled correctly.
|
||||||
|
|
||||||
|
Run `pytest tests/quantization/test_tpu_int8.py`.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.linear import LinearBase
|
||||||
|
from vllm.model_executor.layers.quantization.tpu_int8 import (
|
||||||
|
TPUInt8LinearMethod)
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
from ...models.registry import HF_EXAMPLE_MODELS
|
||||||
|
|
||||||
|
MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not current_platform.is_tpu(),
|
||||||
|
reason="TPU Int8 is only enabled for TPUs.")
|
||||||
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
|
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||||
|
@pytest.mark.parametrize("max_tokens", [10])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"hf_overrides",
|
||||||
|
[
|
||||||
|
# w8a8 dynamic activation
|
||||||
|
{
|
||||||
|
'quantization_config': {
|
||||||
|
'quant_method': 'tpu_int8',
|
||||||
|
'activation_scheme': 'dynamic'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
])
|
||||||
|
def test_model_tpu_int8(vllm_runner, model: str, dtype: str, max_tokens: int,
|
||||||
|
hf_overrides: dict, monkeypatch) -> None:
|
||||||
|
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||||
|
model_info.check_transformers_version(on_fail="skip")
|
||||||
|
|
||||||
|
activation_scheme = hf_overrides.get('quantization_config',
|
||||||
|
{}).get('activation_scheme')
|
||||||
|
quantize_activation = activation_scheme == 'dynamic'
|
||||||
|
|
||||||
|
# Allows using apply_model
|
||||||
|
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||||
|
# Prevent error from re-initializing cache
|
||||||
|
monkeypatch.setenv("VLLM_XLA_CACHE_PATH", "")
|
||||||
|
|
||||||
|
prompts = [
|
||||||
|
"A robot may not injure a human being",
|
||||||
|
"It is only with the heart that one can see rightly;",
|
||||||
|
"The greatest glory in living lies not in never falling,",
|
||||||
|
]
|
||||||
|
answers = [
|
||||||
|
"or, being injured, not kill, except in",
|
||||||
|
"without the heart, one can only see wrongly.",
|
||||||
|
"but in rising every time we fall. - Nelson"
|
||||||
|
]
|
||||||
|
|
||||||
|
with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm:
|
||||||
|
|
||||||
|
def check_model(model):
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if not isinstance(module, LinearBase):
|
||||||
|
continue
|
||||||
|
quant_method = module.quant_method
|
||||||
|
assert isinstance(quant_method, TPUInt8LinearMethod)
|
||||||
|
assert quant_method.quantize_activation == quantize_activation
|
||||||
|
|
||||||
|
vllm.apply_model(check_model)
|
||||||
|
outputs = vllm.generate_greedy(prompts, max_tokens)
|
||||||
|
for (_, output), answer in zip(outputs, answers):
|
||||||
|
assert answer in output
|
||||||
@ -13,7 +13,7 @@ from vllm.model_executor.layers.quantization.base_config import (
|
|||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
from vllm.model_executor.parameter import ModelWeightParameter
|
from vllm.model_executor.parameter import ModelWeightParameter
|
||||||
|
|
||||||
ACTIVATION_SCHEMES = ["none"]
|
ACTIVATION_SCHEMES = ["none", "dynamic"]
|
||||||
|
|
||||||
|
|
||||||
class Int8TpuConfig(QuantizationConfig):
|
class Int8TpuConfig(QuantizationConfig):
|
||||||
@ -61,6 +61,9 @@ class TPUInt8LinearMethod(LinearMethodBase):
|
|||||||
|
|
||||||
def __init__(self, quant_config: Int8TpuConfig):
|
def __init__(self, quant_config: Int8TpuConfig):
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
self.quantize_activation = False
|
||||||
|
if self.quant_config.activation_scheme == 'dynamic':
|
||||||
|
self.quantize_activation = True
|
||||||
|
|
||||||
def create_weights(self, layer: Module, input_size_per_partition: int,
|
def create_weights(self, layer: Module, input_size_per_partition: int,
|
||||||
output_partition_sizes: list[int], input_size: int,
|
output_partition_sizes: list[int], input_size: int,
|
||||||
@ -107,7 +110,7 @@ class TPUInt8LinearMethod(LinearMethodBase):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
try:
|
try:
|
||||||
import torch_xla.experimental.xla_quantized_matmul # noqa: F401
|
import torch_xla.experimental.custom_kernel # noqa: F401
|
||||||
except ImportError as err:
|
except ImportError as err:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Please install torch_xla by following the instructions at "
|
"Please install torch_xla by following the instructions at "
|
||||||
@ -115,7 +118,8 @@ class TPUInt8LinearMethod(LinearMethodBase):
|
|||||||
"to run vLLM on TPU.") from err
|
"to run vLLM on TPU.") from err
|
||||||
weight = layer.weight
|
weight = layer.weight
|
||||||
scale = layer.scale
|
scale = layer.scale
|
||||||
out = torch.ops.xla.quantized_matmul(x, weight, scale)
|
out = torch.ops.xla.quantized_matmul_int8(
|
||||||
|
x, weight, scale, quantize_activation=self.quantize_activation)
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
out = out + bias
|
out = out + bias
|
||||||
return out
|
return out
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user