From 130ac91b8f0486bfd06060d810767ab9c336d263 Mon Sep 17 00:00:00 2001 From: Carsten Bauer Date: Wed, 7 Aug 2024 16:16:06 +0200 Subject: [PATCH] turn off extensions on non-linux systems --- ext/DistributedExt/DistributedExt.jl | 130 +-------------------- ext/DistributedExt/distributed_pinning.jl | 70 +++++++++++ ext/DistributedExt/distributed_querying.jl | 56 +++++++++ ext/MPIExt/MPIExt.jl | 108 +---------------- ext/MPIExt/mpi_pinning.jl | 38 ++++++ ext/MPIExt/mpi_querying.jl | 65 +++++++++++ 6 files changed, 235 insertions(+), 232 deletions(-) create mode 100644 ext/DistributedExt/distributed_pinning.jl create mode 100644 ext/DistributedExt/distributed_querying.jl create mode 100644 ext/MPIExt/mpi_pinning.jl create mode 100644 ext/MPIExt/mpi_querying.jl diff --git a/ext/DistributedExt/DistributedExt.jl b/ext/DistributedExt/DistributedExt.jl index 6a5950f5..aac1153d 100644 --- a/ext/DistributedExt/DistributedExt.jl +++ b/ext/DistributedExt/DistributedExt.jl @@ -3,133 +3,9 @@ module DistributedExt import ThreadPinning: ThreadPinning using Distributed: Distributed -function getworkerpids(; include_master = false) - workers = Distributed.workers() - if include_master && !in(1, workers) - pushfirst!(workers, 1) - end - return workers -end - -# querying -function ThreadPinning.distributed_getcpuids(; include_master = false) - res = Dict{Int, Vector{Int}}() - for w in getworkerpids(; include_master) - res[w] = Distributed.@fetchfrom w ThreadPinning.getcpuids() - end - return res -end - -function ThreadPinning.distributed_gethostnames(; include_master = false) - res = Dict{Int, String}() - for w in getworkerpids(; include_master) - res[w] = Distributed.@fetchfrom w gethostname() - end - return res -end - -function ThreadPinning.distributed_getispinned(; include_master = false) - res = Dict{Int, Vector{Bool}}() - for w in getworkerpids(; include_master) - res[w] = Distributed.@fetchfrom w ThreadPinning.getispinned() - end - return res -end - -function compute_distributed_topology(hostnames_dict) - dist_topo = Vector{@NamedTuple{ - pid::Int64, localid::Int64, node::Int64, nodename::String}}( - undef, length(hostnames_dict)) - sorted_by_pid = sortperm(collect(keys(hostnames_dict))) - nodes = unique(collect(values(hostnames_dict))[sorted_by_pid]) - idx = 1 - for (inode, node) in enumerate(nodes) - workers_onnode = collect(keys(filter(p -> p[2] == node, hostnames_dict))) - sort!(workers_onnode) # on each node we sort by worker pid - for (i, r) in enumerate(workers_onnode) - dist_topo[idx] = (; pid = r, localid = i - 1, node = inode, nodename = node) - idx += 1 - end - end - return dist_topo -end - -function ThreadPinning.distributed_topology(; include_master = false) - hostnames_dict = ThreadPinning.distributed_gethostnames(; include_master) - dist_topo = compute_distributed_topology(hostnames_dict) - return dist_topo -end - -# pinning -function ThreadPinning.distributed_pinthreads(symb::Symbol; - compact = false, - include_master = false, - kwargs...) - domain_symbol2functions(symb) # to check input arg as early as possible - dist_topo = ThreadPinning.distributed_topology(; include_master) - @sync for worker in dist_topo - Distributed.remotecall( - () -> ThreadPinning._distributed_pinyourself( - symb, dist_topo; compact, kwargs...), - worker.pid) - end - return -end - -function ThreadPinning._distributed_pinyourself(symb, dist_topo; compact, kwargs...) 
- # println("_distributed_pinyourself START") - idx = findfirst(w -> w.pid == Distributed.myid(), dist_topo) - if isnothing(idx) - error("Couldn't find myself (worker pid $(Distributed.myid())) in distributed topology.") - end - localid = dist_topo[idx].localid - domain, ndomain = domain_symbol2functions(symb) - # compute cpuids - cpuids = cpuids_of_localid(localid, domain, ndomain; compact) - # actual pinning - ThreadPinning.pinthreads(cpuids; kwargs...) - # println("_distributed_pinyourself STOP") - return -end - -function domain_symbol2functions(symb) - if symb == :sockets - domain = ThreadPinning.socket - ndomain = ThreadPinning.nsockets - elseif symb == :numa - domain = ThreadPinning.numa - ndomain = ThreadPinning.nnuma - elseif symb == :cores - domain = ThreadPinning.core - ndomain = ThreadPinning.ncores - else - throw(ArgumentError("Invalid symbol. Supported symbols are :sockets, :numa, and :cores.")) - end - return domain, ndomain -end - -function cpuids_of_localid(localrank, domain, ndomain; - nthreads_per_proc = Threads.nthreads(), - compact = false) - i_in_domain, idomain = divrem(localrank, ndomain()) .+ 1 - idcs = ((i_in_domain - 1) * nthreads_per_proc + 1):(i_in_domain * nthreads_per_proc) - if maximum(idcs) > length(domain(idomain)) - @show maximum(idcs), length(domain(idomain)) - error("Too many Julia threads / Julia workers for the selected domain.") - end - if domain == ThreadPinning.core - cpuids = domain(idomain, idcs) - else - cpuids = domain(idomain, idcs; compact) - end - return cpuids -end - -function ThreadPinning.distributed_unpinthreads(; include_master = false, kwargs...) - @sync for w in getworkerpids(; include_master) - Distributed.@spawnat w ThreadPinning.unpinthreads(; kwargs...) - end - return +@static if Sys.islinux() + include("distributed_querying.jl") + include("distributed_pinning.jl") end end # module diff --git a/ext/DistributedExt/distributed_pinning.jl b/ext/DistributedExt/distributed_pinning.jl new file mode 100644 index 00000000..4a7ded4e --- /dev/null +++ b/ext/DistributedExt/distributed_pinning.jl @@ -0,0 +1,70 @@ +function ThreadPinning.distributed_pinthreads(symb::Symbol; + compact = false, + include_master = false, + kwargs...) + domain_symbol2functions(symb) # to check input arg as early as possible + dist_topo = ThreadPinning.distributed_topology(; include_master) + @sync for worker in dist_topo + Distributed.remotecall( + () -> ThreadPinning._distributed_pinyourself( + symb, dist_topo; compact, kwargs...), + worker.pid) + end + return +end + +function ThreadPinning._distributed_pinyourself(symb, dist_topo; compact, kwargs...) + # println("_distributed_pinyourself START") + idx = findfirst(w -> w.pid == Distributed.myid(), dist_topo) + if isnothing(idx) + error("Couldn't find myself (worker pid $(Distributed.myid())) in distributed topology.") + end + localid = dist_topo[idx].localid + domain, ndomain = domain_symbol2functions(symb) + # compute cpuids + cpuids = cpuids_of_localid(localid, domain, ndomain; compact) + # actual pinning + ThreadPinning.pinthreads(cpuids; kwargs...) + # println("_distributed_pinyourself STOP") + return +end + +function domain_symbol2functions(symb) + if symb == :sockets + domain = ThreadPinning.socket + ndomain = ThreadPinning.nsockets + elseif symb == :numa + domain = ThreadPinning.numa + ndomain = ThreadPinning.nnuma + elseif symb == :cores + domain = ThreadPinning.core + ndomain = ThreadPinning.ncores + else + throw(ArgumentError("Invalid symbol. 
Supported symbols are :sockets, :numa, and :cores.")) + end + return domain, ndomain +end + +function cpuids_of_localid(localrank, domain, ndomain; + nthreads_per_proc = Threads.nthreads(), + compact = false) + i_in_domain, idomain = divrem(localrank, ndomain()) .+ 1 + idcs = ((i_in_domain - 1) * nthreads_per_proc + 1):(i_in_domain * nthreads_per_proc) + if maximum(idcs) > length(domain(idomain)) + @show maximum(idcs), length(domain(idomain)) + error("Too many Julia threads / Julia workers for the selected domain.") + end + if domain == ThreadPinning.core + cpuids = domain(idomain, idcs) + else + cpuids = domain(idomain, idcs; compact) + end + return cpuids +end + +function ThreadPinning.distributed_unpinthreads(; include_master = false, kwargs...) + @sync for w in getworkerpids(; include_master) + Distributed.@spawnat w ThreadPinning.unpinthreads(; kwargs...) + end + return +end diff --git a/ext/DistributedExt/distributed_querying.jl b/ext/DistributedExt/distributed_querying.jl new file mode 100644 index 00000000..0ecd7ea6 --- /dev/null +++ b/ext/DistributedExt/distributed_querying.jl @@ -0,0 +1,56 @@ +function getworkerpids(; include_master = false) + workers = Distributed.workers() + if include_master && !in(1, workers) + pushfirst!(workers, 1) + end + return workers +end + +# querying +function ThreadPinning.distributed_getcpuids(; include_master = false) + res = Dict{Int, Vector{Int}}() + for w in getworkerpids(; include_master) + res[w] = Distributed.@fetchfrom w ThreadPinning.getcpuids() + end + return res +end + +function ThreadPinning.distributed_gethostnames(; include_master = false) + res = Dict{Int, String}() + for w in getworkerpids(; include_master) + res[w] = Distributed.@fetchfrom w gethostname() + end + return res +end + +function ThreadPinning.distributed_getispinned(; include_master = false) + res = Dict{Int, Vector{Bool}}() + for w in getworkerpids(; include_master) + res[w] = Distributed.@fetchfrom w ThreadPinning.getispinned() + end + return res +end + +function compute_distributed_topology(hostnames_dict) + dist_topo = Vector{@NamedTuple{ + pid::Int64, localid::Int64, node::Int64, nodename::String}}( + undef, length(hostnames_dict)) + sorted_by_pid = sortperm(collect(keys(hostnames_dict))) + nodes = unique(collect(values(hostnames_dict))[sorted_by_pid]) + idx = 1 + for (inode, node) in enumerate(nodes) + workers_onnode = collect(keys(filter(p -> p[2] == node, hostnames_dict))) + sort!(workers_onnode) # on each node we sort by worker pid + for (i, r) in enumerate(workers_onnode) + dist_topo[idx] = (; pid = r, localid = i - 1, node = inode, nodename = node) + idx += 1 + end + end + return dist_topo +end + +function ThreadPinning.distributed_topology(; include_master = false) + hostnames_dict = ThreadPinning.distributed_gethostnames(; include_master) + dist_topo = compute_distributed_topology(hostnames_dict) + return dist_topo +end diff --git a/ext/MPIExt/MPIExt.jl b/ext/MPIExt/MPIExt.jl index 1743e6fb..2243b446 100644 --- a/ext/MPIExt/MPIExt.jl +++ b/ext/MPIExt/MPIExt.jl @@ -3,111 +3,9 @@ module MPIExt import ThreadPinning: ThreadPinning using MPI: MPI -# querying -function ThreadPinning.mpi_getcpuids(; comm = MPI.COMM_WORLD, dest = 0) - rank = MPI.Comm_rank(comm) - cpuids_ranks = MPI.Gather(ThreadPinning.getcpuids(), comm; root = dest) - rank != 0 && return - n_per_rank = Threads.nthreads() - return Dict((k - 1) => collect(v) - for (k, v) in enumerate(Iterators.partition(cpuids_ranks, n_per_rank))) -end - -function ThreadPinning.mpi_gethostnames(; comm = 
MPI.COMM_WORLD, dest = 0) - rank = MPI.Comm_rank(comm) - hostnames_ranks = MPI.gather(gethostname(), comm; root = dest) - rank != 0 && return - return Dict((k - 1) => only(v) - for (k, v) in enumerate(Iterators.partition(hostnames_ranks, 1))) -end - -# function compute_localranks(hostnames_ranks) -# localranks = fill(-1, length(hostnames_ranks)) -# nodes = unique(values(hostnames_ranks)) -# for n in nodes -# ranks_onnode = collect(keys(filter(p -> p[2] == n, hostnames_ranks))) -# sort!(ranks_onnode) # on each node we sort by rank id -# for (i, r) in enumerate(ranks_onnode) -# localranks[r + 1] = i - 1 # -1 because local ranks should start at 0 -# end -# end -# return localranks -# end - -function compute_mpi_topology(hostnames_ranks) - mpi_topo = Vector{@NamedTuple{ - rank::Int64, localrank::Int64, node::Int64, nodename::String}}( - undef, length(hostnames_ranks)) - sorted_by_rank = sortperm(collect(keys(hostnames_ranks))) - nodes = unique(collect(values(hostnames_ranks))[sorted_by_rank]) - for (inode, node) in enumerate(nodes) - ranks_onnode = collect(keys(filter(p -> p[2] == node, hostnames_ranks))) - sort!(ranks_onnode) # on each node we sort by rank id - for (i, r) in enumerate(ranks_onnode) - mpi_topo[r + 1] = (; rank = r, localrank = i - 1, node = inode, nodename = node) - end - end - return mpi_topo -end - -function ThreadPinning.mpi_topology(; comm = MPI.COMM_WORLD) - hostnames_ranks = ThreadPinning.mpi_gethostnames(; comm) - rank = MPI.Comm_rank(comm) - mpi_topo = rank == 0 ? compute_mpi_topology(hostnames_ranks) : nothing - # mpi_topo = MPI.bcast(mpi_topo, comm) - return mpi_topo -end - -function ThreadPinning.mpi_getlocalrank(; comm = MPI.COMM_WORLD) - hostnames_ranks = ThreadPinning.mpi_gethostnames(; comm) - rank = MPI.Comm_rank(comm) - localranks = nothing - if rank == 0 - mpi_topo = compute_mpi_topology(hostnames_ranks) - localranks = [r.localrank for r in mpi_topo] - end - localrank = MPI.scatter(localranks, comm; root = 0) - return localrank -end - -# pinning -function ThreadPinning.mpi_pinthreads(symb::Symbol; - comm = MPI.COMM_WORLD, - compact = false, - nthreads_per_rank = Threads.nthreads(), - kwargs...) - if symb == :sockets - domain = ThreadPinning.socket - ndomain = ThreadPinning.nsockets - elseif symb == :numa - domain = ThreadPinning.numa - ndomain = ThreadPinning.nnuma - elseif symb == :cores - domain = ThreadPinning.core - ndomain = ThreadPinning.ncores - else - throw(ArgumentError("Invalid symbol. Supported symbols are :sockets, :numa, and :cores.")) - end - localrank = ThreadPinning.mpi_getlocalrank(; comm) - cpuids = cpuids_of_localrank(localrank, domain, ndomain; nthreads_per_rank, compact) - ThreadPinning.pinthreads(cpuids; nthreads = nthreads_per_rank, kwargs...) 
- return -end - -function cpuids_of_localrank( - localrank, domain, ndomain; nthreads_per_rank = Threads.nthreads(), compact = false) - i_in_domain, idomain = divrem(localrank, ndomain()) .+ 1 - idcs = ((i_in_domain - 1) * nthreads_per_rank + 1):(i_in_domain * nthreads_per_rank) - if maximum(idcs) > length(domain(idomain)) - @show maximum(idcs), length(domain(idomain)) - error("Too many Julia threads / MPI ranks for the selected domain.") - end - if domain == ThreadPinning.core - cpuids = domain(idomain, idcs) - else - cpuids = domain(idomain, idcs; compact) - end - return cpuids +@static if Sys.islinux() + include("mpi_querying.jl") + include("mpi_pinning.jl") end end # module diff --git a/ext/MPIExt/mpi_pinning.jl b/ext/MPIExt/mpi_pinning.jl new file mode 100644 index 00000000..94c36eae --- /dev/null +++ b/ext/MPIExt/mpi_pinning.jl @@ -0,0 +1,38 @@ +function ThreadPinning.mpi_pinthreads(symb::Symbol; + comm = MPI.COMM_WORLD, + compact = false, + nthreads_per_rank = Threads.nthreads(), + kwargs...) + if symb == :sockets + domain = ThreadPinning.socket + ndomain = ThreadPinning.nsockets + elseif symb == :numa + domain = ThreadPinning.numa + ndomain = ThreadPinning.nnuma + elseif symb == :cores + domain = ThreadPinning.core + ndomain = ThreadPinning.ncores + else + throw(ArgumentError("Invalid symbol. Supported symbols are :sockets, :numa, and :cores.")) + end + localrank = ThreadPinning.mpi_getlocalrank(; comm) + cpuids = cpuids_of_localrank(localrank, domain, ndomain; nthreads_per_rank, compact) + ThreadPinning.pinthreads(cpuids; nthreads = nthreads_per_rank, kwargs...) + return +end + +function cpuids_of_localrank( + localrank, domain, ndomain; nthreads_per_rank = Threads.nthreads(), compact = false) + i_in_domain, idomain = divrem(localrank, ndomain()) .+ 1 + idcs = ((i_in_domain - 1) * nthreads_per_rank + 1):(i_in_domain * nthreads_per_rank) + if maximum(idcs) > length(domain(idomain)) + @show maximum(idcs), length(domain(idomain)) + error("Too many Julia threads / MPI ranks for the selected domain.") + end + if domain == ThreadPinning.core + cpuids = domain(idomain, idcs) + else + cpuids = domain(idomain, idcs; compact) + end + return cpuids +end diff --git a/ext/MPIExt/mpi_querying.jl b/ext/MPIExt/mpi_querying.jl new file mode 100644 index 00000000..36f0e61d --- /dev/null +++ b/ext/MPIExt/mpi_querying.jl @@ -0,0 +1,65 @@ +function ThreadPinning.mpi_getcpuids(; comm = MPI.COMM_WORLD, dest = 0) + rank = MPI.Comm_rank(comm) + cpuids_ranks = MPI.Gather(ThreadPinning.getcpuids(), comm; root = dest) + rank != 0 && return + n_per_rank = Threads.nthreads() + return Dict((k - 1) => collect(v) + for (k, v) in enumerate(Iterators.partition(cpuids_ranks, n_per_rank))) +end + +function ThreadPinning.mpi_gethostnames(; comm = MPI.COMM_WORLD, dest = 0) + rank = MPI.Comm_rank(comm) + hostnames_ranks = MPI.gather(gethostname(), comm; root = dest) + rank != 0 && return + return Dict((k - 1) => only(v) + for (k, v) in enumerate(Iterators.partition(hostnames_ranks, 1))) +end + +# function compute_localranks(hostnames_ranks) +# localranks = fill(-1, length(hostnames_ranks)) +# nodes = unique(values(hostnames_ranks)) +# for n in nodes +# ranks_onnode = collect(keys(filter(p -> p[2] == n, hostnames_ranks))) +# sort!(ranks_onnode) # on each node we sort by rank id +# for (i, r) in enumerate(ranks_onnode) +# localranks[r + 1] = i - 1 # -1 because local ranks should start at 0 +# end +# end +# return localranks +# end + +function compute_mpi_topology(hostnames_ranks) + mpi_topo = Vector{@NamedTuple{ + 
rank::Int64, localrank::Int64, node::Int64, nodename::String}}( + undef, length(hostnames_ranks)) + sorted_by_rank = sortperm(collect(keys(hostnames_ranks))) + nodes = unique(collect(values(hostnames_ranks))[sorted_by_rank]) + for (inode, node) in enumerate(nodes) + ranks_onnode = collect(keys(filter(p -> p[2] == node, hostnames_ranks))) + sort!(ranks_onnode) # on each node we sort by rank id + for (i, r) in enumerate(ranks_onnode) + mpi_topo[r + 1] = (; rank = r, localrank = i - 1, node = inode, nodename = node) + end + end + return mpi_topo +end + +function ThreadPinning.mpi_topology(; comm = MPI.COMM_WORLD) + hostnames_ranks = ThreadPinning.mpi_gethostnames(; comm) + rank = MPI.Comm_rank(comm) + mpi_topo = rank == 0 ? compute_mpi_topology(hostnames_ranks) : nothing + # mpi_topo = MPI.bcast(mpi_topo, comm) + return mpi_topo +end + +function ThreadPinning.mpi_getlocalrank(; comm = MPI.COMM_WORLD) + hostnames_ranks = ThreadPinning.mpi_gethostnames(; comm) + rank = MPI.Comm_rank(comm) + localranks = nothing + if rank == 0 + mpi_topo = compute_mpi_topology(hostnames_ranks) + localranks = [r.localrank for r in mpi_topo] + end + localrank = MPI.scatter(localranks, comm; root = 0) + return localrank +end
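
Note, not part of the patch itself: with the new `@static if Sys.islinux()` guard, DistributedExt and MPIExt still load as extension modules on every platform, but on non-Linux systems they no longer define any methods; how the corresponding ThreadPinning functions behave there is determined by the parent package, which this patch does not touch. A hedged caller-side sketch of how downstream code might guard accordingly (the `:cores` choice and the warning text are illustrative assumptions, not from this patch):

    using ThreadPinning

    if Sys.islinux()
        # pinning support is only included on Linux after this patch
        ThreadPinning.pinthreads(:cores)
    else
        @warn "Skipping thread pinning: only supported on Linux."
    end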
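For reference, a minimal usage sketch of the Distributed-based API that this patch moves into distributed_querying.jl / distributed_pinning.jl. It assumes a Linux node; the worker count and the `:numa` domain choice are illustrative (the supported symbols are `:sockets`, `:numa`, and `:cores`):

    using Distributed
    addprocs(4)                               # illustrative worker count
    @everywhere using ThreadPinning           # loading both packages activates DistributedExt

    # pin each worker's threads, assigning workers round-robin to the NUMA
    # domains of their node based on their node-local id
    ThreadPinning.distributed_pinthreads(:numa)

    ThreadPinning.distributed_getcpuids()     # Dict: worker pid => CPU IDs of its threads
    ThreadPinning.distributed_topology()      # pid / localid / node / nodename per worker
    ThreadPinning.distributed_unpinthreads()  # undo the pinning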
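And the analogous sketch for the MPI-based API (mpi_querying.jl / mpi_pinning.jl). The launch command and thread count are assumptions; note that, as in the code above, the gathering queries only materialize their result on rank 0 and return `nothing` on all other ranks:

    # launched e.g. as: mpiexecjl -n 8 julia -t 4 script.jl   (illustrative)
    using MPI, ThreadPinning
    MPI.Init()

    # pin this rank's Julia threads into the NUMA domains of its node,
    # using the node-local rank computed via mpi_getlocalrank()
    ThreadPinning.mpi_pinthreads(:numa)

    cpuids = ThreadPinning.mpi_getcpuids()    # rank => CPU IDs; rank 0 only, else nothing
    topo   = ThreadPinning.mpi_topology()     # rank / localrank / node / nodename; rank 0 only

    MPI.Finalize()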