mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:45:29 +08:00
74 lines
2.6 KiB
Python
74 lines
2.6 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""Tests whether TPU Int8 computation is enabled correctly.
|
|
|
|
Run `pytest tests/quantization/test_tpu_int8.py`.
|
|
"""
|
|
import pytest
|
|
|
|
from vllm.model_executor.layers.linear import LinearBase
|
|
from vllm.model_executor.layers.quantization.tpu_int8 import (
|
|
TPUInt8LinearMethod)
|
|
from vllm.platforms import current_platform
|
|
|
|
from ...models.registry import HF_EXAMPLE_MODELS
|
|
|
|
MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]
|
|
|
|
|
|
@pytest.mark.skipif(not current_platform.is_tpu(),
|
|
reason="TPU Int8 is only enabled for TPUs.")
|
|
@pytest.mark.parametrize("model", MODELS)
|
|
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
|
@pytest.mark.parametrize("max_tokens", [10])
|
|
@pytest.mark.parametrize(
|
|
"hf_overrides",
|
|
[
|
|
# w8a8 dynamic activation
|
|
{
|
|
'quantization_config': {
|
|
'quant_method': 'tpu_int8',
|
|
'activation_scheme': 'dynamic'
|
|
}
|
|
}
|
|
])
|
|
def test_model_tpu_int8(vllm_runner, model: str, dtype: str, max_tokens: int,
|
|
hf_overrides: dict, monkeypatch) -> None:
|
|
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
|
model_info.check_transformers_version(on_fail="skip")
|
|
|
|
activation_scheme = hf_overrides.get('quantization_config',
|
|
{}).get('activation_scheme')
|
|
quantize_activation = activation_scheme == 'dynamic'
|
|
|
|
# Allows using apply_model
|
|
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
|
# Prevent error from re-initializing cache
|
|
monkeypatch.setenv("VLLM_XLA_CACHE_PATH", "")
|
|
|
|
prompts = [
|
|
"A robot may not injure a human being",
|
|
"It is only with the heart that one can see rightly;",
|
|
"The greatest glory in living lies not in never falling,",
|
|
]
|
|
answers = [
|
|
"or, being injured, not kill, except in",
|
|
"without the heart, one can only see wrongly.",
|
|
"but in rising every time we fall. - Nelson"
|
|
]
|
|
|
|
with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm:
|
|
|
|
def check_model(model):
|
|
for name, module in model.named_modules():
|
|
if not isinstance(module, LinearBase):
|
|
continue
|
|
quant_method = module.quant_method
|
|
assert isinstance(quant_method, TPUInt8LinearMethod)
|
|
assert quant_method.quantize_activation == quantize_activation
|
|
|
|
vllm.apply_model(check_model)
|
|
outputs = vllm.generate_greedy(prompts, max_tokens)
|
|
for (_, output), answer in zip(outputs, answers):
|
|
assert answer in output
|