
Commit

turn off extensions on non-linux systems
carstenbauer committed Aug 7, 2024
1 parent ae3f79f commit 130ac91
Showing 6 changed files with 235 additions and 232 deletions.
130 changes: 3 additions & 127 deletions ext/DistributedExt/DistributedExt.jl
@@ -3,133 +3,9 @@ module DistributedExt
import ThreadPinning: ThreadPinning
using Distributed: Distributed

function getworkerpids(; include_master = false)
workers = Distributed.workers()
if include_master && !in(1, workers)
pushfirst!(workers, 1)
end
return workers
end

# querying
function ThreadPinning.distributed_getcpuids(; include_master = false)
res = Dict{Int, Vector{Int}}()
for w in getworkerpids(; include_master)
res[w] = Distributed.@fetchfrom w ThreadPinning.getcpuids()
end
return res
end

function ThreadPinning.distributed_gethostnames(; include_master = false)
res = Dict{Int, String}()
for w in getworkerpids(; include_master)
res[w] = Distributed.@fetchfrom w gethostname()
end
return res
end

function ThreadPinning.distributed_getispinned(; include_master = false)
res = Dict{Int, Vector{Bool}}()
for w in getworkerpids(; include_master)
res[w] = Distributed.@fetchfrom w ThreadPinning.getispinned()
end
return res
end

function compute_distributed_topology(hostnames_dict)
dist_topo = Vector{@NamedTuple{
pid::Int64, localid::Int64, node::Int64, nodename::String}}(
undef, length(hostnames_dict))
sorted_by_pid = sortperm(collect(keys(hostnames_dict)))
nodes = unique(collect(values(hostnames_dict))[sorted_by_pid])
idx = 1
for (inode, node) in enumerate(nodes)
workers_onnode = collect(keys(filter(p -> p[2] == node, hostnames_dict)))
sort!(workers_onnode) # on each node we sort by worker pid
for (i, r) in enumerate(workers_onnode)
dist_topo[idx] = (; pid = r, localid = i - 1, node = inode, nodename = node)
idx += 1
end
end
return dist_topo
end

function ThreadPinning.distributed_topology(; include_master = false)
hostnames_dict = ThreadPinning.distributed_gethostnames(; include_master)
dist_topo = compute_distributed_topology(hostnames_dict)
return dist_topo
end

# pinning
function ThreadPinning.distributed_pinthreads(symb::Symbol;
compact = false,
include_master = false,
kwargs...)
domain_symbol2functions(symb) # to check input arg as early as possible
dist_topo = ThreadPinning.distributed_topology(; include_master)
@sync for worker in dist_topo
Distributed.remotecall(
() -> ThreadPinning._distributed_pinyourself(
symb, dist_topo; compact, kwargs...),
worker.pid)
end
return
end

function ThreadPinning._distributed_pinyourself(symb, dist_topo; compact, kwargs...)
# println("_distributed_pinyourself START")
idx = findfirst(w -> w.pid == Distributed.myid(), dist_topo)
if isnothing(idx)
error("Couldn't find myself (worker pid $(Distributed.myid())) in distributed topology.")
end
localid = dist_topo[idx].localid
domain, ndomain = domain_symbol2functions(symb)
# compute cpuids
cpuids = cpuids_of_localid(localid, domain, ndomain; compact)
# actual pinning
ThreadPinning.pinthreads(cpuids; kwargs...)
# println("_distributed_pinyourself STOP")
return
end

function domain_symbol2functions(symb)
if symb == :sockets
domain = ThreadPinning.socket
ndomain = ThreadPinning.nsockets
elseif symb == :numa
domain = ThreadPinning.numa
ndomain = ThreadPinning.nnuma
elseif symb == :cores
domain = ThreadPinning.core
ndomain = ThreadPinning.ncores
else
throw(ArgumentError("Invalid symbol. Supported symbols are :sockets, :numa, and :cores."))
end
return domain, ndomain
end

function cpuids_of_localid(localrank, domain, ndomain;
nthreads_per_proc = Threads.nthreads(),
compact = false)
i_in_domain, idomain = divrem(localrank, ndomain()) .+ 1
idcs = ((i_in_domain - 1) * nthreads_per_proc + 1):(i_in_domain * nthreads_per_proc)
if maximum(idcs) > length(domain(idomain))
@show maximum(idcs), length(domain(idomain))
error("Too many Julia threads / Julia workers for the selected domain.")
end
if domain == ThreadPinning.core
cpuids = domain(idomain, idcs)
else
cpuids = domain(idomain, idcs; compact)
end
return cpuids
end

function ThreadPinning.distributed_unpinthreads(; include_master = false, kwargs...)
@sync for w in getworkerpids(; include_master)
Distributed.@spawnat w ThreadPinning.unpinthreads(; kwargs...)
end
return
end

@static if Sys.islinux()
include("distributed_querying.jl")
include("distributed_pinning.jl")
end

end # module
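
For reference, after this commit the whole of ext/DistributedExt/DistributedExt.jl reduces to roughly the following (reassembled from the unchanged header and the added lines above; exact blank lines and indentation are assumptions):

module DistributedExt

import ThreadPinning: ThreadPinning
using Distributed: Distributed

@static if Sys.islinux()
    include("distributed_querying.jl")
    include("distributed_pinning.jl")
end

end # module
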
70 changes: 70 additions & 0 deletions ext/DistributedExt/distributed_pinning.jl
@@ -0,0 +1,70 @@
function ThreadPinning.distributed_pinthreads(symb::Symbol;
compact = false,
include_master = false,
kwargs...)
domain_symbol2functions(symb) # to check input arg as early as possible
dist_topo = ThreadPinning.distributed_topology(; include_master)
@sync for worker in dist_topo
Distributed.remotecall(
() -> ThreadPinning._distributed_pinyourself(
symb, dist_topo; compact, kwargs...),
worker.pid)
end
return
end

function ThreadPinning._distributed_pinyourself(symb, dist_topo; compact, kwargs...)
# println("_distributed_pinyourself START")
idx = findfirst(w -> w.pid == Distributed.myid(), dist_topo)
if isnothing(idx)
error("Couldn't find myself (worker pid $(Distributed.myid())) in distributed topology.")
end
localid = dist_topo[idx].localid
domain, ndomain = domain_symbol2functions(symb)
# compute cpuids
cpuids = cpuids_of_localid(localid, domain, ndomain; compact)
# actual pinning
ThreadPinning.pinthreads(cpuids; kwargs...)
# println("_distributed_pinyourself STOP")
return
end

function domain_symbol2functions(symb)
if symb == :sockets
domain = ThreadPinning.socket
ndomain = ThreadPinning.nsockets
elseif symb == :numa
domain = ThreadPinning.numa
ndomain = ThreadPinning.nnuma
elseif symb == :cores
domain = ThreadPinning.core
ndomain = ThreadPinning.ncores
else
throw(ArgumentError("Invalid symbol. Supported symbols are :sockets, :numa, and :cores."))
end
return domain, ndomain
end

function cpuids_of_localid(localrank, domain, ndomain;
nthreads_per_proc = Threads.nthreads(),
compact = false)
i_in_domain, idomain = divrem(localrank, ndomain()) .+ 1
idcs = ((i_in_domain - 1) * nthreads_per_proc + 1):(i_in_domain * nthreads_per_proc)
if maximum(idcs) > length(domain(idomain))
@show maximum(idcs), length(domain(idomain))
error("Too many Julia threads / Julia workers for the selected domain.")
end
if domain == ThreadPinning.core
cpuids = domain(idomain, idcs)
else
cpuids = domain(idomain, idcs; compact)
end
return cpuids
end

function ThreadPinning.distributed_unpinthreads(; include_master = false, kwargs...)
@sync for w in getworkerpids(; include_master)
Distributed.@spawnat w ThreadPinning.unpinthreads(; kwargs...)
end
return
end
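
A minimal usage sketch for the pinning entry points defined in this new file, run from the Julia master process. The worker count, thread count, and the choice of :numa are illustrative assumptions, and on non-Linux systems these functions are no longer defined by the extension:

using Distributed
addprocs(4; exeflags = "--threads=8")   # hypothetical: 4 workers with 8 Julia threads each
@everywhere using ThreadPinning         # loading ThreadPinning + Distributed activates DistributedExt

# Pin each worker's threads within the NUMA domain assigned to it (round-robin by local worker id).
ThreadPinning.distributed_pinthreads(:numa)

# Undo the pinning on all workers (and, optionally, on the master as well).
ThreadPinning.distributed_unpinthreads(; include_master = false)
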
56 changes: 56 additions & 0 deletions ext/DistributedExt/distributed_querying.jl
@@ -0,0 +1,56 @@
function getworkerpids(; include_master = false)
workers = Distributed.workers()
if include_master && !in(1, workers)
pushfirst!(workers, 1)
end
return workers
end

# querying
function ThreadPinning.distributed_getcpuids(; include_master = false)
res = Dict{Int, Vector{Int}}()
for w in getworkerpids(; include_master)
res[w] = Distributed.@fetchfrom w ThreadPinning.getcpuids()
end
return res
end

function ThreadPinning.distributed_gethostnames(; include_master = false)
res = Dict{Int, String}()
for w in getworkerpids(; include_master)
res[w] = Distributed.@fetchfrom w gethostname()
end
return res
end

function ThreadPinning.distributed_getispinned(; include_master = false)
res = Dict{Int, Vector{Bool}}()
for w in getworkerpids(; include_master)
res[w] = Distributed.@fetchfrom w ThreadPinning.getispinned()
end
return res
end

function compute_distributed_topology(hostnames_dict)
dist_topo = Vector{@NamedTuple{
pid::Int64, localid::Int64, node::Int64, nodename::String}}(
undef, length(hostnames_dict))
sorted_by_pid = sortperm(collect(keys(hostnames_dict)))
nodes = unique(collect(values(hostnames_dict))[sorted_by_pid])
idx = 1
for (inode, node) in enumerate(nodes)
workers_onnode = collect(keys(filter(p -> p[2] == node, hostnames_dict)))
sort!(workers_onnode) # on each node we sort by worker pid
for (i, r) in enumerate(workers_onnode)
dist_topo[idx] = (; pid = r, localid = i - 1, node = inode, nodename = node)
idx += 1
end
end
return dist_topo
end

function ThreadPinning.distributed_topology(; include_master = false)
hostnames_dict = ThreadPinning.distributed_gethostnames(; include_master)
dist_topo = compute_distributed_topology(hostnames_dict)
return dist_topo
end
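
And a sketch of the querying helpers from this file, again called on the master process; the returned values are system dependent and the worker setup from the previous sketch is assumed:

using Distributed, ThreadPinning

ThreadPinning.distributed_getcpuids()       # Dict: worker pid => CPU IDs currently used by its Julia threads
ThreadPinning.distributed_gethostnames()    # Dict: worker pid => hostname
ThreadPinning.distributed_getispinned()     # Dict: worker pid => per-thread "is pinned?" flags
ThreadPinning.distributed_topology(; include_master = true)   # Vector of (; pid, localid, node, nodename)
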
108 changes: 3 additions & 105 deletions ext/MPIExt/MPIExt.jl
@@ -3,111 +3,9 @@ module MPIExt
import ThreadPinning: ThreadPinning
using MPI: MPI

# querying
function ThreadPinning.mpi_getcpuids(; comm = MPI.COMM_WORLD, dest = 0)
rank = MPI.Comm_rank(comm)
cpuids_ranks = MPI.Gather(ThreadPinning.getcpuids(), comm; root = dest)
rank != 0 && return
n_per_rank = Threads.nthreads()
return Dict((k - 1) => collect(v)
for (k, v) in enumerate(Iterators.partition(cpuids_ranks, n_per_rank)))
end

function ThreadPinning.mpi_gethostnames(; comm = MPI.COMM_WORLD, dest = 0)
rank = MPI.Comm_rank(comm)
hostnames_ranks = MPI.gather(gethostname(), comm; root = dest)
rank != 0 && return
return Dict((k - 1) => only(v)
for (k, v) in enumerate(Iterators.partition(hostnames_ranks, 1)))
end

# function compute_localranks(hostnames_ranks)
# localranks = fill(-1, length(hostnames_ranks))
# nodes = unique(values(hostnames_ranks))
# for n in nodes
# ranks_onnode = collect(keys(filter(p -> p[2] == n, hostnames_ranks)))
# sort!(ranks_onnode) # on each node we sort by rank id
# for (i, r) in enumerate(ranks_onnode)
# localranks[r + 1] = i - 1 # -1 because local ranks should start at 0
# end
# end
# return localranks
# end

function compute_mpi_topology(hostnames_ranks)
mpi_topo = Vector{@NamedTuple{
rank::Int64, localrank::Int64, node::Int64, nodename::String}}(
undef, length(hostnames_ranks))
sorted_by_rank = sortperm(collect(keys(hostnames_ranks)))
nodes = unique(collect(values(hostnames_ranks))[sorted_by_rank])
for (inode, node) in enumerate(nodes)
ranks_onnode = collect(keys(filter(p -> p[2] == node, hostnames_ranks)))
sort!(ranks_onnode) # on each node we sort by rank id
for (i, r) in enumerate(ranks_onnode)
mpi_topo[r + 1] = (; rank = r, localrank = i - 1, node = inode, nodename = node)
end
end
return mpi_topo
end

function ThreadPinning.mpi_topology(; comm = MPI.COMM_WORLD)
hostnames_ranks = ThreadPinning.mpi_gethostnames(; comm)
rank = MPI.Comm_rank(comm)
mpi_topo = rank == 0 ? compute_mpi_topology(hostnames_ranks) : nothing
# mpi_topo = MPI.bcast(mpi_topo, comm)
return mpi_topo
end

function ThreadPinning.mpi_getlocalrank(; comm = MPI.COMM_WORLD)
hostnames_ranks = ThreadPinning.mpi_gethostnames(; comm)
rank = MPI.Comm_rank(comm)
localranks = nothing
if rank == 0
mpi_topo = compute_mpi_topology(hostnames_ranks)
localranks = [r.localrank for r in mpi_topo]
end
localrank = MPI.scatter(localranks, comm; root = 0)
return localrank
end

# pinning
function ThreadPinning.mpi_pinthreads(symb::Symbol;
comm = MPI.COMM_WORLD,
compact = false,
nthreads_per_rank = Threads.nthreads(),
kwargs...)
if symb == :sockets
domain = ThreadPinning.socket
ndomain = ThreadPinning.nsockets
elseif symb == :numa
domain = ThreadPinning.numa
ndomain = ThreadPinning.nnuma
elseif symb == :cores
domain = ThreadPinning.core
ndomain = ThreadPinning.ncores
else
throw(ArgumentError("Invalid symbol. Supported symbols are :sockets, :numa, and :cores."))
end
localrank = ThreadPinning.mpi_getlocalrank(; comm)
cpuids = cpuids_of_localrank(localrank, domain, ndomain; nthreads_per_rank, compact)
ThreadPinning.pinthreads(cpuids; nthreads = nthreads_per_rank, kwargs...)
return
end

function cpuids_of_localrank(
localrank, domain, ndomain; nthreads_per_rank = Threads.nthreads(), compact = false)
i_in_domain, idomain = divrem(localrank, ndomain()) .+ 1
idcs = ((i_in_domain - 1) * nthreads_per_rank + 1):(i_in_domain * nthreads_per_rank)
if maximum(idcs) > length(domain(idomain))
@show maximum(idcs), length(domain(idomain))
error("Too many Julia threads / MPI ranks for the selected domain.")
end
if domain == ThreadPinning.core
cpuids = domain(idomain, idcs)
else
cpuids = domain(idomain, idcs; compact)
end
return cpuids
end

@static if Sys.islinux()
include("mpi_querying.jl")
include("mpi_pinning.jl")
end

end # module
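
Analogously, a usage sketch for the MPI extension; this assumes a Linux cluster and an MPI launch such as mpiexecjl -n 4 julia --threads=2 script.jl (the launcher and the counts are assumptions):

using MPI, ThreadPinning
MPI.Init()

# Each rank pins its own threads within the socket assigned to its node-local rank.
ThreadPinning.mpi_pinthreads(:sockets)

# Rank 0 (the default dest) receives a Dict rank => CPU IDs; the other ranks get nothing.
cpuids = ThreadPinning.mpi_getcpuids()
if MPI.Comm_rank(MPI.COMM_WORLD) == 0
    @show cpuids
end

MPI.Finalize()
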
