From dd344e03425087ebcfc3f98f91821c7e5d316832 Mon Sep 17 00:00:00 2001 From: yarongmu-google <150371854+yarongmu-google@users.noreply.github.com> Date: Fri, 14 Mar 2025 17:41:15 -0700 Subject: [PATCH] =?UTF-8?q?[Bugfix]=20Fix=20torch=5Fxla=20in=20V0=20which?= =?UTF-8?q?=20can't=20handle=20None=20seed=20introduced=20=E2=80=A6=20(#14?= =?UTF-8?q?844)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Yarong Mu --- vllm/worker/tpu_worker.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 1a5eaba09b940..66911790662eb 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -51,6 +51,9 @@ class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase): self.model_runner: TPUModelRunner = TPUModelRunner( vllm_config=vllm_config, is_driver_worker=is_driver_worker) + if self.model_config.seed is None: + self.model_config.seed = 0 + def init_device(self) -> None: os.environ["PJRT_DEVICE"] = "TPU" torch.set_grad_enabled(False)