mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 21:35:01 +08:00
[Fix] Avoid pickling entire LLMEngine for Ray workers (#3207)
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
This commit is contained in:
parent
8999ec3c16
commit
2efce05dc3
@ -158,6 +158,11 @@ class LLMEngine:
|
|||||||
if USE_RAY_COMPILED_DAG:
|
if USE_RAY_COMPILED_DAG:
|
||||||
self.forward_dag = self._compiled_ray_dag()
|
self.forward_dag = self._compiled_ray_dag()
|
||||||
|
|
||||||
|
def __reduce__(self):
|
||||||
|
# This is to ensure that the LLMEngine is not referenced in
|
||||||
|
# the closure used to initialize Ray worker actors
|
||||||
|
raise RuntimeError("LLMEngine should not be pickled!")
|
||||||
|
|
||||||
def get_tokenizer_for_seq(self, sequence: Sequence):
|
def get_tokenizer_for_seq(self, sequence: Sequence):
|
||||||
return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
|
return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
|
||||||
|
|
||||||
@ -280,6 +285,8 @@ class LLMEngine:
|
|||||||
parallel_config = copy.deepcopy(self.parallel_config)
|
parallel_config = copy.deepcopy(self.parallel_config)
|
||||||
scheduler_config = copy.deepcopy(self.scheduler_config)
|
scheduler_config = copy.deepcopy(self.scheduler_config)
|
||||||
device_config = copy.deepcopy(self.device_config)
|
device_config = copy.deepcopy(self.device_config)
|
||||||
|
lora_config = copy.deepcopy(self.lora_config)
|
||||||
|
kv_cache_dtype = self.cache_config.cache_dtype
|
||||||
|
|
||||||
for rank, (worker, (node_id,
|
for rank, (worker, (node_id,
|
||||||
_)) in enumerate(zip(self.workers,
|
_)) in enumerate(zip(self.workers,
|
||||||
@ -295,22 +302,22 @@ class LLMEngine:
|
|||||||
local_rank,
|
local_rank,
|
||||||
rank,
|
rank,
|
||||||
distributed_init_method,
|
distributed_init_method,
|
||||||
lora_config=self.lora_config,
|
lora_config=lora_config,
|
||||||
kv_cache_dtype=self.cache_config.cache_dtype,
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
))
|
))
|
||||||
|
|
||||||
driver_rank = 0
|
driver_rank = 0
|
||||||
driver_local_rank = node_workers[driver_node_id].index(driver_rank)
|
driver_local_rank = node_workers[driver_node_id].index(driver_rank)
|
||||||
self.driver_worker = Worker(
|
self.driver_worker = Worker(
|
||||||
model_config,
|
self.model_config,
|
||||||
parallel_config,
|
self.parallel_config,
|
||||||
scheduler_config,
|
self.scheduler_config,
|
||||||
device_config,
|
self.device_config,
|
||||||
driver_local_rank,
|
driver_local_rank,
|
||||||
driver_rank,
|
driver_rank,
|
||||||
distributed_init_method,
|
distributed_init_method,
|
||||||
lora_config=self.lora_config,
|
lora_config=self.lora_config,
|
||||||
kv_cache_dtype=self.cache_config.cache_dtype,
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
is_driver_worker=True,
|
is_driver_worker=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user