mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 16:17:27 +08:00
Remove hardcoded path
This commit is contained in:
parent
84284302d8
commit
092e3d6d6d
@ -43,12 +43,17 @@ class TPUModelRunner:
|
||||
# )
|
||||
self.compiled_fn = self._execute_step
|
||||
|
||||
def load_model(self) -> None:
|
||||
def load_model(self) -> None:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from vllm.model_executor.models.jax.gemma import Transformer
|
||||
|
||||
assert self.model_config.hf_config.model_type == "gemma"
|
||||
self.model = Transformer(self.model_config.hf_config)
|
||||
self.params = load_and_format_params(
|
||||
"/home/woosukk/.cache/huggingface/hub/models--google--gemma-7b-flax/snapshots/255139998d76ac69e797fd4b4e8c4b562dc3c75f/7b")
|
||||
|
||||
model_name = "google/gemma-7b-flax"
|
||||
model_dir = snapshot_download(model_name)
|
||||
self.params = load_and_format_params(model_dir + "/7b/")
|
||||
|
||||
def _prepare_prompt(
|
||||
self,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user