mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:05:02 +08:00
[Fix] Avoid pickling entire LLMEngine for Ray workers (#3207)
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
This commit is contained in:
parent
8999ec3c16
commit
2efce05dc3
@ -158,6 +158,11 @@ class LLMEngine:
|
||||
if USE_RAY_COMPILED_DAG:
|
||||
self.forward_dag = self._compiled_ray_dag()
|
||||
|
||||
def __reduce__(self):
|
||||
# This is to ensure that the LLMEngine is not referenced in
|
||||
# the closure used to initialize Ray worker actors
|
||||
raise RuntimeError("LLMEngine should not be pickled!")
|
||||
|
||||
def get_tokenizer_for_seq(self, sequence: Sequence):
|
||||
return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
|
||||
|
||||
@ -280,6 +285,8 @@ class LLMEngine:
|
||||
parallel_config = copy.deepcopy(self.parallel_config)
|
||||
scheduler_config = copy.deepcopy(self.scheduler_config)
|
||||
device_config = copy.deepcopy(self.device_config)
|
||||
lora_config = copy.deepcopy(self.lora_config)
|
||||
kv_cache_dtype = self.cache_config.cache_dtype
|
||||
|
||||
for rank, (worker, (node_id,
|
||||
_)) in enumerate(zip(self.workers,
|
||||
@ -295,22 +302,22 @@ class LLMEngine:
|
||||
local_rank,
|
||||
rank,
|
||||
distributed_init_method,
|
||||
lora_config=self.lora_config,
|
||||
kv_cache_dtype=self.cache_config.cache_dtype,
|
||||
lora_config=lora_config,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
))
|
||||
|
||||
driver_rank = 0
|
||||
driver_local_rank = node_workers[driver_node_id].index(driver_rank)
|
||||
self.driver_worker = Worker(
|
||||
model_config,
|
||||
parallel_config,
|
||||
scheduler_config,
|
||||
device_config,
|
||||
self.model_config,
|
||||
self.parallel_config,
|
||||
self.scheduler_config,
|
||||
self.device_config,
|
||||
driver_local_rank,
|
||||
driver_rank,
|
||||
distributed_init_method,
|
||||
lora_config=self.lora_config,
|
||||
kv_cache_dtype=self.cache_config.cache_dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
is_driver_worker=True,
|
||||
)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user