Remove hardcoded path

This commit is contained in:
Woosuk Kwon 2024-04-19 08:18:10 +00:00
parent 84284302d8
commit 092e3d6d6d

View File

@ -43,12 +43,17 @@ class TPUModelRunner:
# )
self.compiled_fn = self._execute_step
def load_model(self) -> None:
def load_model(self) -> None:
from huggingface_hub import snapshot_download
from vllm.model_executor.models.jax.gemma import Transformer
assert self.model_config.hf_config.model_type == "gemma"
self.model = Transformer(self.model_config.hf_config)
self.params = load_and_format_params(
"/home/woosukk/.cache/huggingface/hub/models--google--gemma-7b-flax/snapshots/255139998d76ac69e797fd4b4e8c4b562dc3c75f/7b")
model_name = "google/gemma-7b-flax"
model_dir = snapshot_download(model_name)
self.params = load_and_format_params(model_dir + "/7b/")
def _prepare_prompt(
self,