[Fix] Avoid pickling entire LLMEngine for Ray workers (#3207)

Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
Authored by Nick Hill on 2024-03-05 16:17:20 -08:00; committed by GitHub
parent 8999ec3c16
commit 2efce05dc3


@@ -158,6 +158,11 @@ class LLMEngine:
         if USE_RAY_COMPILED_DAG:
             self.forward_dag = self._compiled_ray_dag()
 
+    def __reduce__(self):
+        # This is to ensure that the LLMEngine is not referenced in
+        # the closure used to initialize Ray worker actors
+        raise RuntimeError("LLMEngine should not be pickled!")
+
     def get_tokenizer_for_seq(self, sequence: Sequence):
         return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
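
The __reduce__ override added here turns any accidental pickling of the engine into an immediate, explicit failure instead of a silent and expensive serialization of the whole engine. A minimal sketch of the same guard outside vLLM (the Engine class and its big_state field are illustrative, not vLLM code):

    import pickle


    class Engine:
        """Stand-in for a heavyweight object that must never be serialized."""

        def __init__(self):
            self.big_state = [0] * 1_000_000  # pretend this is costly to pickle

        def __reduce__(self):
            # pickle (and copy.deepcopy, which falls back to the reduce
            # protocol) calls this, so any accidental serialization fails fast.
            raise RuntimeError("Engine should not be pickled!")


    try:
        pickle.dumps(Engine())
    except RuntimeError as err:
        print(f"caught: {err}")  # caught: Engine should not be pickled!
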
@@ -280,6 +285,8 @@ class LLMEngine:
         parallel_config = copy.deepcopy(self.parallel_config)
         scheduler_config = copy.deepcopy(self.scheduler_config)
         device_config = copy.deepcopy(self.device_config)
+        lora_config = copy.deepcopy(self.lora_config)
+        kv_cache_dtype = self.cache_config.cache_dtype
 
         for rank, (worker, (node_id,
                             _)) in enumerate(zip(self.workers,
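
The two new locals, lora_config and kv_cache_dtype, exist so that the worker-initialization closure Ray serializes (the Worker(...) call in the next hunk) captures plain config values rather than `self`; a closure that read self.lora_config or self.cache_config would drag the entire LLMEngine into the pickle, which the __reduce__ guard above now forbids. A small illustration of the difference, assuming cloudpickle is installed (Ray uses it for serialization); the Engine, init_worker, and factory names below are made up for the example:

    import cloudpickle  # Ray serializes remote arguments and closures with cloudpickle


    def init_worker(kv_cache_dtype):
        # Placeholder for the real worker constructor.
        return kv_cache_dtype


    class Engine:
        def __init__(self):
            self.kv_cache_dtype = "auto"
            self.big_state = [0] * 1_000_000  # expensive baggage

        def bad_factory(self):
            # Closure references `self`, so serializing it drags the whole
            # Engine, big_state included, into the pickle.
            return lambda: init_worker(self.kv_cache_dtype)

        def good_factory(self):
            # Bind the needed value to a local first; only that value travels.
            kv_cache_dtype = self.kv_cache_dtype
            return lambda: init_worker(kv_cache_dtype)


    e = Engine()
    print(len(cloudpickle.dumps(e.good_factory())))  # small: a string in the closure
    print(len(cloudpickle.dumps(e.bad_factory())))   # much larger: the whole Engine went along
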
@@ -295,22 +302,22 @@ class LLMEngine:
                     local_rank,
                     rank,
                     distributed_init_method,
-                    lora_config=self.lora_config,
-                    kv_cache_dtype=self.cache_config.cache_dtype,
+                    lora_config=lora_config,
+                    kv_cache_dtype=kv_cache_dtype,
                 ))
 
         driver_rank = 0
         driver_local_rank = node_workers[driver_node_id].index(driver_rank)
         self.driver_worker = Worker(
-            model_config,
-            parallel_config,
-            scheduler_config,
-            device_config,
+            self.model_config,
+            self.parallel_config,
+            self.scheduler_config,
+            self.device_config,
             driver_local_rank,
             driver_rank,
             distributed_init_method,
             lora_config=self.lora_config,
-            kv_cache_dtype=self.cache_config.cache_dtype,
+            kv_cache_dtype=kv_cache_dtype,
             is_driver_worker=True,
         )
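
Note the asymmetry the final hunk leaves in place: the driver worker is constructed in the engine's own process, so it can keep taking self.model_config and friends directly, while only the remote workers' closure must stay free of `self`. With the __reduce__ guard in place, a regression that reintroduces an engine reference into that closure fails loudly at serialization time. A hedged sketch of that interplay (illustrative names, not a vLLM test):

    import cloudpickle


    class Engine:
        def __init__(self):
            self.kv_cache_dtype = "auto"

        def __reduce__(self):
            raise RuntimeError("Engine should not be pickled!")


    def make_closures(engine):
        # Remote-worker style: bind a plain local, keeping the closure engine-free.
        kv_cache_dtype = engine.kv_cache_dtype
        safe = lambda: kv_cache_dtype
        # Regression style: the closure reaches back into the engine.
        unsafe = lambda: engine.kv_cache_dtype
        return safe, unsafe


    safe, unsafe = make_closures(Engine())
    cloudpickle.dumps(safe)  # serializes cleanly, like the copied-config closure
    try:
        cloudpickle.dumps(unsafe)
    except RuntimeError as err:
        print(f"caught: {err}")  # the guard flags the captured engine immediately
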