diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index 16a9f0959b5c5..5da481baeeea7 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -101,7 +101,10 @@ class TPUWorker:
         # fix this. It will be removed after the bug in XLA compiler is fixed.
         os.environ["LIBTPU_INIT_ARGS"] = (
             os.environ.get("LIBTPU_INIT_ARGS", "") +
-            " --xla_tpu_force_1d_allreduce_at_chunk_count=1")
+            " --xla_tpu_force_1d_allreduce_at_chunk_count=1"
+            " --xla_jf_conv_input_fusion=False")
+        # --xla_jf_conv_input_fusion=False is used to improve the perf of
+        # quantized matmul.
         torch.set_grad_enabled(False)
         torch.set_default_dtype(self.model_config.dtype)
 
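
For context (not part of the patch): LIBTPU_INIT_ARGS is a single space-separated string of XLA/libtpu flags that the TPU runtime reads once at initialization, which is why the diff appends to any existing value rather than overwriting it, and why the assignment must happen before the first TPU call in the process. The sketch below illustrates that pattern in isolation; the append_libtpu_flags helper is hypothetical, introduced here only for illustration.

    # Minimal sketch, assuming flags are set before the TPU runtime starts.
    import os

    def append_libtpu_flags(*flags: str) -> None:
        # Hypothetical helper: preserve any flags already set by the user,
        # then append ours, space-separated, mirroring the diff above.
        existing = os.environ.get("LIBTPU_INIT_ARGS", "")
        os.environ["LIBTPU_INIT_ARGS"] = " ".join([existing, *flags]).strip()

    append_libtpu_flags(
        "--xla_tpu_force_1d_allreduce_at_chunk_count=1",
        "--xla_jf_conv_input_fusion=False",
    )
    print(os.environ["LIBTPU_INIT_ARGS"])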