diff --git a/csrc/cpu/utils.cpp b/csrc/cpu/utils.cpp index c5a48352e3089..5199ba2af024f 100644 --- a/csrc/cpu/utils.cpp +++ b/csrc/cpu/utils.cpp @@ -45,31 +45,54 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) { // Memory node binding if (numa_available() != -1) { int mem_node_id = numa_node_of_cpu(omp_cpu_ids.front()); - // Verify all CPUs are on the same NUMA node - for (size_t i = 1; i < omp_cpu_ids.size(); ++i) { - int node_id = numa_node_of_cpu(omp_cpu_ids[i]); - TORCH_CHECK(node_id == mem_node_id, "CPU ", omp_cpu_ids[i], - " is on NUMA node ", node_id, ", but CPU ", - omp_cpu_ids.front(), " is on NUMA node ", mem_node_id, - ". All CPUs should be on the same NUMA node for optimal " - "performance. Memory will be bound to NUMA node ", - mem_node_id, "."); + std::set node_ids; + for (const auto& cpu_id : omp_cpu_ids) { + int node_id = numa_node_of_cpu(cpu_id); + if (node_id != -1) { + node_ids.insert(node_id); + } + TORCH_WARN(node_id == mem_node_id, "CPU ", cpu_id, " is on NUMA node ", + node_id, ", but CPU ", omp_cpu_ids.front(), + " is on NUMA node ", mem_node_id, + ". All CPUs should be on the same NUMA node for optimal " + "performance. Memory will be bound to NUMA node ", + mem_node_id, "."); } - bitmask* mask = numa_parse_nodestring(std::to_string(mem_node_id).c_str()); - bitmask* src_mask = numa_get_membind(); + // Concatenate all node_ids into a single comma-separated string + if (!node_ids.empty()) { + std::string node_ids_str; + for (const int node_id : node_ids) { + if (!node_ids_str.empty()) { + node_ids_str += ","; + } + node_ids_str += std::to_string(node_id); + } - int pid = getpid(); + bitmask* mask = numa_parse_nodestring(node_ids_str.c_str()); + bitmask* src_mask = numa_get_membind(); - // move all existing pages to the specified numa node. - *(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp); - int page_num = numa_migrate_pages(pid, src_mask, mask); - if (page_num == -1) { - TORCH_WARN("numa_migrate_pages failed. errno: " + std::to_string(errno)); + int pid = getpid(); + + if (mask && src_mask) { + // move all existing pages to the specified numa node. + *(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp); + int page_num = numa_migrate_pages(pid, src_mask, mask); + if (page_num == -1) { + TORCH_WARN("numa_migrate_pages failed. errno: " + + std::to_string(errno)); + } + + // restrict memory allocation node. + numa_set_membind(mask); + numa_set_strict(1); + + numa_free_nodemask(mask); + numa_free_nodemask(src_mask); + } else { + TORCH_WARN("numa_parse_nodestring or numa_get_membind failed. errno: " + + std::to_string(errno)); + } } - - // restrict memory allocation node. - numa_set_membind(mask); - numa_set_strict(1); } // OMP threads binding