mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-28 17:37:08 +08:00
[Bugfix] Add custom Triton cache manager to resolve MoE MP issue (#6140)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Chih-Chieh-Yang <chih.chieh.yang@ibm.com>
This commit is contained in:
parent
a63a4c6341
commit
eaec4b9153
@ -9,6 +9,7 @@ from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
|
|||||||
ResultHandler, WorkerMonitor)
|
ResultHandler, WorkerMonitor)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||||
|
from vllm.triton_utils import maybe_set_triton_cache_manager
|
||||||
from vllm.utils import (cuda_device_count_stateless,
|
from vllm.utils import (cuda_device_count_stateless,
|
||||||
error_on_invalid_device_count_status,
|
error_on_invalid_device_count_status,
|
||||||
get_distributed_init_method, get_open_port,
|
get_distributed_init_method, get_open_port,
|
||||||
@ -42,6 +43,10 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
|
|||||||
if "OMP_NUM_THREADS" not in os.environ:
|
if "OMP_NUM_THREADS" not in os.environ:
|
||||||
os.environ["OMP_NUM_THREADS"] = "1"
|
os.environ["OMP_NUM_THREADS"] = "1"
|
||||||
|
|
||||||
|
# workaround for https://github.com/vllm-project/vllm/issues/6103
|
||||||
|
if world_size > 1:
|
||||||
|
maybe_set_triton_cache_manager()
|
||||||
|
|
||||||
assert world_size <= cuda_device_count_stateless(), (
|
assert world_size <= cuda_device_count_stateless(), (
|
||||||
"please set tensor_parallel_size to less than max local gpu count")
|
"please set tensor_parallel_size to less than max local gpu count")
|
||||||
|
|
||||||
|
|||||||
6
vllm/triton_utils/__init__.py
Normal file
6
vllm/triton_utils/__init__.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
from vllm.triton_utils.custom_cache_manager import (
|
||||||
|
maybe_set_triton_cache_manager)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"maybe_set_triton_cache_manager",
|
||||||
|
]
|
||||||
53
vllm/triton_utils/custom_cache_manager.py
Normal file
53
vllm/triton_utils/custom_cache_manager.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from triton.runtime.cache import (FileCacheManager, default_cache_dir,
|
||||||
|
default_dump_dir, default_override_dir)
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_set_triton_cache_manager() -> None:
|
||||||
|
"""Set environment variable to tell Triton to use a
|
||||||
|
custom cache manager"""
|
||||||
|
cache_manger = os.environ.get("TRITON_CACHE_MANAGER", None)
|
||||||
|
if cache_manger is None:
|
||||||
|
manager = "vllm.triton_utils.custom_cache_manager:CustomCacheManager"
|
||||||
|
logger.info("Setting Triton cache manager to: %s", manager)
|
||||||
|
os.environ["TRITON_CACHE_MANAGER"] = manager
|
||||||
|
|
||||||
|
|
||||||
|
class CustomCacheManager(FileCacheManager):
|
||||||
|
"""Re-implements Triton's cache manager, ensuring that a
|
||||||
|
unique cache directory is created for each process. This is
|
||||||
|
needed to avoid collisions when running with tp>1 and
|
||||||
|
using multi-processing as the distributed backend.
|
||||||
|
|
||||||
|
Note this issue was fixed by triton-lang/triton/pull/4295,
|
||||||
|
but the fix is not yet included in triton==v3.0.0. However,
|
||||||
|
it should be included in the subsequent version.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, key, override=False, dump=False):
|
||||||
|
self.key = key
|
||||||
|
self.lock_path = None
|
||||||
|
if dump:
|
||||||
|
self.cache_dir = default_dump_dir()
|
||||||
|
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||||
|
self.lock_path = os.path.join(self.cache_dir, "lock")
|
||||||
|
os.makedirs(self.cache_dir, exist_ok=True)
|
||||||
|
elif override:
|
||||||
|
self.cache_dir = default_override_dir()
|
||||||
|
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||||
|
else:
|
||||||
|
# create cache directory if it doesn't exist
|
||||||
|
self.cache_dir = os.getenv("TRITON_CACHE_DIR",
|
||||||
|
"").strip() or default_cache_dir()
|
||||||
|
if self.cache_dir:
|
||||||
|
self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
|
||||||
|
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||||
|
self.lock_path = os.path.join(self.cache_dir, "lock")
|
||||||
|
os.makedirs(self.cache_dir, exist_ok=True)
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Could not create or locate cache dir")
|
||||||
Loading…
x
Reference in New Issue
Block a user