diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index d3da79a676e47..3278c5dada3b3 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -43,12 +43,17 @@ class TPUModelRunner:
         #     )
         self.compiled_fn = self._execute_step
 
-    def load_model(self) -> None:
+    def load_model(self) -> None:
+        from huggingface_hub import snapshot_download
+        from vllm.model_executor.models.jax.gemma import Transformer
+
         assert self.model_config.hf_config.model_type == "gemma"
         self.model = Transformer(self.model_config.hf_config)
-        self.params = load_and_format_params(
-            "/home/woosukk/.cache/huggingface/hub/models--google--gemma-7b-flax/snapshots/255139998d76ac69e797fd4b4e8c4b562dc3c75f/7b")
+
+        model_name = "google/gemma-7b-flax"
+        model_dir = snapshot_download(model_name)
+        self.params = load_and_format_params(model_dir + "/7b/")
 
     def _prepare_prompt(
         self,