diff --git a/csrc/cpu/utils.cpp b/csrc/cpu/utils.cpp
index 88bc3c509790c..f2085b73b6a48 100644
--- a/csrc/cpu/utils.cpp
+++ b/csrc/cpu/utils.cpp
@@ -24,6 +24,8 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
 #ifndef VLLM_NUMA_DISABLED
 std::string init_cpu_threads_env(const std::string& cpu_ids) {
   bitmask* omp_cpu_mask = numa_parse_cpustring_all(cpu_ids.c_str());
+  TORCH_CHECK(omp_cpu_mask != nullptr,
+              "Failed to parse CPU string: " + cpu_ids);
   TORCH_CHECK(omp_cpu_mask->size > 0);
   std::vector<int> omp_cpu_ids;
   omp_cpu_ids.reserve(omp_cpu_mask->size);
@@ -44,20 +46,12 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
 
   // Memory node binding
   if (numa_available() != -1) {
-    int mem_node_id = numa_node_of_cpu(omp_cpu_ids.front());
     std::set<int> node_ids;
     for (const auto& cpu_id : omp_cpu_ids) {
       int node_id = numa_node_of_cpu(cpu_id);
       if (node_id != -1) {
         node_ids.insert(node_id);
       }
-      if (node_id != mem_node_id) {
-        TORCH_WARN("CPU ", cpu_id, " is on NUMA node ", node_id, ", but CPU ",
-                   omp_cpu_ids.front(), " is on NUMA node ", mem_node_id,
-                   ". All CPUs should be on the same NUMA node for optimal "
-                   "performance. Memory will be bound to NUMA node ",
-                   mem_node_id, ".");
-      }
     }
     // Concatenate all node_ids into a single comma-separated string
     if (!node_ids.empty()) {
@@ -70,7 +64,7 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
       }
 
       bitmask* mask = numa_parse_nodestring(node_ids_str.c_str());
-      bitmask* src_mask = numa_get_membind();
+      bitmask* src_mask = numa_get_mems_allowed();
 
       int pid = getpid();
 
@@ -83,15 +77,49 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
                      std::to_string(errno));
         }
 
-        // restrict memory allocation node.
-        numa_set_membind(mask);
+        // Restrict memory allocation to the selected NUMA node(s) to improve
+        // memory locality for the threads bound to those NUMA CPUs: interleave
+        // when the CPUs span several nodes, otherwise bind to the single node.
+        if (node_ids.size() > 1) {
+          // Clear errno first: failure is detected below solely via errno.
+          errno = 0;
+          numa_set_interleave_mask(mask);
+          if (errno != 0) {
+            TORCH_WARN("numa_set_interleave_mask failed. errno: " +
+                       std::to_string(errno));
+          } else {
+            TORCH_WARN(
+                "NUMA binding: Using INTERLEAVE policy for memory "
+                "allocation across multiple NUMA nodes (nodes: " +
+                node_ids_str +
+                "). Memory allocations will be "
+                "interleaved across the specified NUMA nodes.");
+          }
+        } else {
+          // Clear errno first: failure is detected below solely via errno.
+          errno = 0;
+          numa_set_membind(mask);
+          if (errno != 0) {
+            TORCH_WARN("numa_set_membind failed. errno: " +
+                       std::to_string(errno));
+          } else {
+            TORCH_WARN(
+                "NUMA binding: Using MEMBIND policy for memory "
+                "allocation on the NUMA nodes (" +
+                node_ids_str +
+                "). Memory allocations will be "
+                "strictly bound to these NUMA nodes.");
+          }
+        }
+
         numa_set_strict(1);
 
         numa_free_nodemask(mask);
         numa_free_nodemask(src_mask);
       } else {
-        TORCH_WARN("numa_parse_nodestring or numa_get_membind failed. errno: " +
-                   std::to_string(errno));
+        TORCH_WARN(
+            "numa_parse_nodestring or numa_get_mems_allowed failed. errno: " +
+            std::to_string(errno));
       }
     }
   }