mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-02 00:17:54 +08:00
Remove hardcoded path
This commit is contained in:
parent
84284302d8
commit
092e3d6d6d
@ -43,12 +43,17 @@ class TPUModelRunner:
|
|||||||
# )
|
# )
|
||||||
self.compiled_fn = self._execute_step
|
self.compiled_fn = self._execute_step
|
||||||
|
|
||||||
def load_model(self) -> None:
|
def load_model(self) -> None:
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
from vllm.model_executor.models.jax.gemma import Transformer
|
from vllm.model_executor.models.jax.gemma import Transformer
|
||||||
|
|
||||||
|
assert self.model_config.hf_config.model_type == "gemma"
|
||||||
self.model = Transformer(self.model_config.hf_config)
|
self.model = Transformer(self.model_config.hf_config)
|
||||||
self.params = load_and_format_params(
|
|
||||||
"/home/woosukk/.cache/huggingface/hub/models--google--gemma-7b-flax/snapshots/255139998d76ac69e797fd4b4e8c4b562dc3c75f/7b")
|
model_name = "google/gemma-7b-flax"
|
||||||
|
model_dir = snapshot_download(model_name)
|
||||||
|
self.params = load_and_format_params(model_dir + "/7b/")
|
||||||
|
|
||||||
def _prepare_prompt(
|
def _prepare_prompt(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user