Use w8a8 quantized matmul Pallas kernel (#19170)
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
parent 946aadb4a0
commit d4170fad39
@@ -18,9 +18,9 @@ setuptools==78.1.0
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch==2.9.0.dev20250703
-torchvision==0.24.0.dev20250703
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250703-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250703-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250703-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch==2.9.0.dev20250711
+torchvision==0.24.0.dev20250711
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250711-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250711-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250711-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
 
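Note (not part of the diff): the torch_xla nightly bump above is what ships the int8 quantized matmul Pallas op used further down in this commit. A minimal sanity-check sketch, assuming a TPU VM with the pinned nightlies installed; the hasattr probe is only an illustration, not vLLM code.

# Sanity-check sketch: confirm the pinned nightlies expose the int8 Pallas op.
import torch
import torch_xla
import torch_xla.experimental.custom_kernel  # noqa: F401  # registers torch.ops.xla.* custom ops

print(torch.__version__, torch_xla.__version__)
# The new kernel path below relies on this op being registered by the nightly build.
assert hasattr(torch.ops.xla, "quantized_matmul_int8")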
@@ -14,7 +14,7 @@ RTOL = 0.03
 @dataclass
 class GSM8KAccuracyTestConfig:
     model_name: str
-    excepted_value: float
+    expected_value: float
 
     def get_model_args(self) -> str:
         return (f"pretrained={self.model_name},"
@@ -25,13 +25,13 @@ class GSM8KAccuracyTestConfig:
 ACCURACY_CONFIGS = [
     GSM8KAccuracyTestConfig(
         model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8",
-        excepted_value=0.76),  # no bias
+        expected_value=0.76),  # no bias
     # NOTE(rob): We cannot re-initialize vLLM in the same process for TPU,
     # so only one of these tests can run in a single call to pytest. As
     # a follow up, move this into the LM-EVAL section of the CI.
     # GSM8KAccuracyTestConfig(
     #     model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8",
-    #     excepted_value=0.66),  # bias in QKV layers
+    #     expected_value=0.66),  # bias in QKV layers
 ]
 
 
@@ -45,7 +45,7 @@ def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig):
         batch_size="auto",
     )
 
-    EXPECTED_VALUE = config.excepted_value
+    EXPECTED_VALUE = config.expected_value
    measured_value = results["results"][TASK][FILTER]
     assert (measured_value - RTOL < EXPECTED_VALUE
             and measured_value + RTOL > EXPECTED_VALUE
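The excepted_value to expected_value rename is purely a typo fix; the accuracy gate itself is unchanged. For reference, the final assertion is an absolute-tolerance check around the expected GSM8K score, with the 0.03 tolerance coming from the RTOL constant in the hunk header. A standalone sketch of that check; the helper name is illustrative only.

# Illustrative helper (not in the diff): the test's assertion
#   measured_value - RTOL < EXPECTED_VALUE and measured_value + RTOL > EXPECTED_VALUE
# is equivalent to this absolute-difference check.
def within_tolerance(measured: float, expected: float, tol: float = 0.03) -> bool:
    return abs(measured - expected) < tol

assert within_tolerance(0.77, expected=0.76)      # inside the accuracy bar
assert not within_tolerance(0.70, expected=0.76)  # would fail the gate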
@@ -145,3 +145,35 @@ def test_gemma3_27b_with_text_input_and_tp(
     for output, answer in zip(vllm_outputs, answers):
         generated_text = output[1]
         assert answer in generated_text
+
+
+@pytest.mark.skipif(not current_platform.is_tpu(),
+                    reason="This is a basic test for TPU only")
+def test_w8a8_quantization(
+    vllm_runner: type[VllmRunner],
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    model = "neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8"
+    max_tokens = 5
+    tensor_parallel_size = 1
+    max_num_seqs = 4
+
+    prompt = "The next numbers of the sequence " + ", ".join(
+        str(i) for i in range(1024)) + " are:"
+    example_prompts = [prompt]
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        with vllm_runner(
+                model,
+                max_num_batched_tokens=64,
+                max_model_len=4096,
+                gpu_memory_utilization=0.7,
+                max_num_seqs=max_num_seqs,
+                tensor_parallel_size=tensor_parallel_size) as vllm_model:
+            vllm_outputs = vllm_model.generate_greedy(example_prompts,
+                                                      max_tokens)
+        output = vllm_outputs[0][1]
+
+        assert "1024" in output or "0, 1" in output
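A note on the new test's shape choices (my reading, not stated in the diff): the prompt enumerates 0 through 1023, which tokenizes to well over a thousand tokens, while max_num_batched_tokens=64 caps each scheduler step, so the prefill is presumably split into many chunks and the int8 matmul path is exercised across many batch shapes before the 5-token decode. A rough back-of-the-envelope sketch:

# Rough arithmetic only; actual token counts depend on the Llama tokenizer.
prompt = "The next numbers of the sequence " + ", ".join(
    str(i) for i in range(1024)) + " are:"

approx_tokens = len(prompt.split())        # crude lower bound, roughly 1030 "words"
budget = 64                                # max_num_batched_tokens in the test
min_chunks = -(-approx_tokens // budget)   # ceiling division
print(f"at least {min_chunks} prefill steps of <= {budget} tokens")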
@@ -90,16 +90,15 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         w_q, w_s, _, _, _ = self._get_weight_params(layer)
 
-        import torch_xla.experimental.xla_quantized_matmul  # noqa: F401
-        out = torch.ops.xla.quantized_matmul(x,
-                                             w_q,
-                                             w_s,
-                                             zero_point=None,
-                                             block_size=-1,
-                                             int4_weight=False,
-                                             quantize_activation=True)
-        # `quantized_matmul` output is fp32, cast it down to bf16 for perf
-        out = out.to(x.dtype)
+        # Required to register custom ops.
+        import torch_xla.experimental.custom_kernel  # noqa: F401
+        out = torch.ops.xla.quantized_matmul_int8(
+            x,
+            w_q,
+            w_s,
+            quantize_activation=True,
+        )
+
         # Explicitly capture control flow to make dynamo happy.
         # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
         return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])
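For anyone who wants to poke at the new op outside vLLM, here is a minimal usage sketch. It assumes a TPU VM with the pinned torch_xla nightly, an int8 weight laid out as [out_features, in_features] with a per-output-channel float scale, and treats the dequantized bf16 matmul as a hypothetical reference; only the quantized_matmul_int8 call itself mirrors the kernel code above.

# Standalone sketch, not vLLM code. Weight layout and reference are assumptions.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.custom_kernel  # noqa: F401  # registers torch.ops.xla.* custom ops

device = xm.xla_device()
x = torch.randn(8, 512, dtype=torch.bfloat16, device=device)                   # activations
w_q = torch.randint(-128, 128, (1024, 512), dtype=torch.int8, device=device)   # int8 weight
w_s = torch.rand(1024, dtype=torch.float32, device=device) * 0.02              # per-channel scale

# Same call shape as XLAScaledMMLinearKernel.apply_weights above.
out = torch.ops.xla.quantized_matmul_int8(x, w_q, w_s, quantize_activation=True)

# Hypothetical reference: dequantize the weight and matmul in bf16; the
# quantized path should land within quantization error of this.
ref = x @ (w_q.to(torch.float32) * w_s[:, None]).to(torch.bfloat16).t()
print(out.shape, ref.shape)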