[core] fix sleep mode in pytorch 2.6 (#13456)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao (committed by GitHub)
Date: 2025-02-18 13:48:10 +08:00
parent a1074b3efe
commit ac19b519ed

@@ -9,7 +9,7 @@
 # the only successful approach is to call cuda driver API in C.
 import dataclasses
 from contextlib import contextmanager
-from typing import Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 import torch
@@ -97,7 +97,7 @@ def use_memory_pool_with_allocator(
     new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
     mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator)
     with torch.cuda.memory.use_mem_pool(mem_pool):
-        yield mem_pool
+        yield mem_pool, new_alloc
 
 
 class CuMemAllocator:
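
With this change, use_memory_pool_with_allocator yields the allocator alongside the pool, so a caller can keep the allocator alive for as long as the pool is in use. Below is a hedged sketch of a call site, assuming the surrounding module's definitions; my_malloc_cb and my_free_cb are hypothetical stand-ins whose bodies are elided:

    # Hedged sketch, not vLLM's actual call site. The context manager now
    # yields (mem_pool, new_alloc) instead of just mem_pool.
    kept_refs = {}  # strong references, analogous to allocator_and_pools below

    def my_malloc_cb(allocation_handle):
        ...  # hypothetical: record the allocation handle

    def my_free_cb(ptr):
        ...  # hypothetical: look up and release the allocation

    with use_memory_pool_with_allocator(my_malloc_cb, my_free_cb) as (pool, alloc):
        kept_refs["default"] = (pool, alloc)  # pin both for the pool's lifetime
        buf = torch.empty(1 << 20, device="cuda")  # served through `pool`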
@@ -142,6 +142,7 @@ class CuMemAllocator:
     def __init__(self):
         self.pointer_to_data: Dict[int, AllocationData] = {}
         self.current_tag: str = CuMemAllocator.default_tag
+        self.allocator_and_pools: Dict[str, Any] = {}
 
     def python_malloc_callback(self, allocation_handle: HandleType) -> None:
         """
@@ -231,7 +232,13 @@ class CuMemAllocator:
         old_tag = self.current_tag
         self.current_tag = tag
         with use_memory_pool_with_allocator(self.python_malloc_callback,
-                                            self.python_free_callback):
+                                            self.python_free_callback) as data:
+            # In PyTorch 2.6 we start to hit another PyTorch bug,
+            # possibly a gc-related issue w.r.t. the allocator and
+            # the memory pool. To avoid it, we keep a reference to
+            # the yielded data.
+            # See https://github.com/pytorch/pytorch/issues/146431 .
+            self.allocator_and_pools[tag] = data
             yield
 
         # PyTorch's bug, calling torch.cuda.empty_cache() will error
         # when using pluggable allocator, see
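
The workaround in the last hunk is a general object-lifetime fix: if the only Python reference to the allocator lives inside the context manager's frame, CPython can free it while the memory pool (and PyTorch's C++ side) still needs it. A self-contained sketch of the hazard and the fix, using a hypothetical stand-in class rather than PyTorch itself:

    import gc
    import weakref

    class AllocatorStandIn:
        """Hypothetical stand-in for the pluggable allocator object."""

    def make_pool_without_pinning():
        alloc = AllocatorStandIn()
        # Only a weak reference escapes: this models a caller that drops
        # the allocator even though the pool may still dispatch to it.
        return weakref.ref(alloc)

    ref = make_pool_without_pinning()
    gc.collect()
    assert ref() is None  # allocator already collected: a dangling handle

    # The fix mirrors `self.allocator_and_pools[tag] = data`: pin a strong
    # reference for as long as the pool may be used.
    pinned = {}
    alloc = AllocatorStandIn()
    pinned["weights"] = alloc  # "weights" mirrors a CuMemAllocator tag
    ref2 = weakref.ref(alloc)
    del alloc
    gc.collect()
    assert ref2() is not None  # still reachable via the registry

Keying the registry by tag means the pinned pair lives as long as the CuMemAllocator instance itself, which matches the lifetime of the pools it creates.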