Use XLA flag to improve quantized model performance (#19303)

Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
Author: XiongfeiWei
Date:   2025-06-09 18:28:45 -07:00
Commit: 9af6d22e4c (parent 4589b94032)

@@ -101,7 +101,10 @@ class TPUWorker:
         # fix this. It will be removed after the bug in XLA compiler is fixed.
         os.environ["LIBTPU_INIT_ARGS"] = (
             os.environ.get("LIBTPU_INIT_ARGS", "") +
-            " --xla_tpu_force_1d_allreduce_at_chunk_count=1")
+            " --xla_tpu_force_1d_allreduce_at_chunk_count=1"
+            " --xla_jf_conv_input_fusion=False")
+        # --xla_jf_conv_input_fusion=False is used to improve the perf of
+        # quantized matmul.
         torch.set_grad_enabled(False)
         torch.set_default_dtype(self.model_config.dtype)
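
For reference, a minimal standalone sketch (not part of the vLLM source, just restating the change above) of how the two XLA flags end up in LIBTPU_INIT_ARGS. Adjacent Python string literals are concatenated implicitly, so any value already set in the environment is preserved and both flags are passed to libtpu as a single space-separated argument string:

import os

# Append the XLA flags to LIBTPU_INIT_ARGS, keeping any pre-existing value.
# The two string literals are joined by Python into one string, which libtpu
# reads when it initializes.
os.environ["LIBTPU_INIT_ARGS"] = (
    os.environ.get("LIBTPU_INIT_ARGS", "") +
    " --xla_tpu_force_1d_allreduce_at_chunk_count=1"
    " --xla_jf_conv_input_fusion=False")

print(os.environ["LIBTPU_INIT_ARGS"])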