[core] fix sleep mode in pytorch 2.6 (#13456)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2025-02-18 13:48:10 +08:00 committed by GitHub
parent a1074b3efe
commit ac19b519ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -9,7 +9,7 @@
# the only successful approach is to call cuda driver API in C.
import dataclasses
from contextlib import contextmanager
from typing import Callable, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
@ -97,7 +97,7 @@ def use_memory_pool_with_allocator(
new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator)
with torch.cuda.memory.use_mem_pool(mem_pool):
yield mem_pool
yield mem_pool, new_alloc
class CuMemAllocator:
@ -142,6 +142,7 @@ class CuMemAllocator:
def __init__(self):
self.pointer_to_data: Dict[int, AllocationData] = {}
self.current_tag: str = CuMemAllocator.default_tag
self.allocator_and_pools: Dict[str, Any] = {}
def python_malloc_callback(self, allocation_handle: HandleType) -> None:
"""
@ -231,7 +232,13 @@ class CuMemAllocator:
old_tag = self.current_tag
self.current_tag = tag
with use_memory_pool_with_allocator(self.python_malloc_callback,
self.python_free_callback):
self.python_free_callback) as data:
# PyTorch 2.6 triggers another PyTorch bug here, possibly caused by
# a gc-related issue with the allocator and the memory pool.
# To avoid the issue, we keep a reference to the data.
# See https://github.com/pytorch/pytorch/issues/146431.
self.allocator_and_pools[tag] = data
yield
# PyTorch's bug, calling torch.cuda.empty_cache() will error
# when using pluggable allocator, see