From 9af6d22e4c8ef79972911efaab2a68f4b03def2a Mon Sep 17 00:00:00 2001
From: XiongfeiWei
Date: Mon, 9 Jun 2025 18:28:45 -0700
Subject: [PATCH] Use xla flag to improve the quantized model performance
 (#19303)

Signed-off-by: Xiongfei Wei
---
 vllm/v1/worker/tpu_worker.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index 16a9f0959b5c5..5da481baeeea7 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -101,7 +101,10 @@ class TPUWorker:
         # fix this. It will be removed after the bug in XLA compiler is fixed.
         os.environ["LIBTPU_INIT_ARGS"] = (
             os.environ.get("LIBTPU_INIT_ARGS", "") +
-            " --xla_tpu_force_1d_allreduce_at_chunk_count=1")
+            " --xla_tpu_force_1d_allreduce_at_chunk_count=1"
+            " --xla_jf_conv_input_fusion=False")
+        # --xla_jf_conv_input_fusion=False is used to improve the perf of
+        # quantized matmul.
 
         torch.set_grad_enabled(False)
         torch.set_default_dtype(self.model_config.dtype)
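
Note: a minimal standalone sketch of the flag-appending pattern used in the hunk above,
so the libtpu flags accumulate instead of overwriting anything already set in the
environment. The helper name `append_libtpu_flags` is illustrative and is not part of
vLLM; only the two flags and the LIBTPU_INIT_ARGS variable come from the patch itself.

    import os

    def append_libtpu_flags(*flags: str) -> None:
        """Append XLA/libtpu flags without clobbering flags already set."""
        existing = os.environ.get("LIBTPU_INIT_ARGS", "")
        os.environ["LIBTPU_INIT_ARGS"] = existing + "".join(f" {f}" for f in flags)

    # Mirrors the two flags set in TPUWorker above: the first works around an
    # XLA all-reduce issue, the second disables conv input fusion to improve
    # quantized matmul performance.
    append_libtpu_flags(
        "--xla_tpu_force_1d_allreduce_at_chunk_count=1",
        "--xla_jf_conv_input_fusion=False",
    )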