fix "tansformers_module" ModuleNotFoundError when load model with trust_remote_code=True (#871)

This commit is contained in:
Jingru 2023-09-09 08:21:30 +08:00 committed by GitHub
parent 1117aa1411
commit 4042d192f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 2 deletions

View File

@ -153,7 +153,7 @@ class LLMEngine:
placement_group=placement_group,
placement_group_capture_child_tasks=True),
**ray_remote_kwargs,
)(RayWorker).remote()
)(RayWorker).remote(self.model_config.trust_remote_code)
self.workers.append(worker)
# Initialize torch distributed process group for the workers.

View File

@ -11,7 +11,11 @@ try:
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
def __init__(self) -> None:
def __init__(self, init_cached_hf_modules=False) -> None:
if init_cached_hf_modules:
# pylint: disable=import-outside-toplevel
from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules()
self.worker = None
def init_worker(self, worker_init_fn):