# This file relies on names from Distributed, Serialization, and MPI; in the
# package these are normally imported by the enclosing MPIClusterManagers
# module, and are spelled out here so the file reads as self-contained.
using Distributed, Serialization
import MPI

import Base: kill
import Sockets: connect, listenany, accept, IPv4, getsockname, getaddrinfo, wait_connected
################################################################################
# MPI Cluster Manager
# Note: The cluster manager object lives only in the manager process,
# except for MPI_TRANSPORT_ALL

# There are three different transport modes:
#
# MPI_ON_WORKERS: Use MPI between the workers only, not for the manager. This
#   allows interactive use from a Julia shell, using the familiar `addprocs`
#   interface.
#
# MPI_TRANSPORT_ALL: Use MPI on all processes; there is no separate manager
#   process. This corresponds to the "usual" way in which MPI is used in a
#   headless mode, e.g. submitted as a script to a queueing system.
#
# TCP_TRANSPORT_ALL: Same as MPI_TRANSPORT_ALL, but Julia uses TCP for its
#   inter-process communication. MPI can still be used by the user.
@enum TransportMode MPI_ON_WORKERS MPI_TRANSPORT_ALL TCP_TRANSPORT_ALL
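
# Usage sketch for MPI_ON_WORKERS (mirrors the package README; not executed
# here): interactive use from a Julia shell, with MPI spoken only among the
# workers:
#
#   using MPIClusterManagers, Distributed
#   mgr = MPIManager(np=4)   # describes 4 MPI worker processes
#   addprocs(mgr)            # mpiexec's the workers and registers them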
mutable struct MPIManager <: ClusterManager
    np::Int                     # number of worker processes (excluding the manager process)
    mpi2j::Dict{Int,Int}        # map MPI ranks to Julia pids
    j2mpi::Dict{Int,Int}        # map Julia pids to MPI ranks
    mode::TransportMode
    launched::Bool              # are the MPI processes running?
    launch_timeout::Int         # seconds
    initialized::Bool           # have all workers registered with us?
    cond_initialized::Condition # notified when all workers have registered

    # TCP transport
    port::UInt16
    ip::UInt32
    stdout_ios::Array           # sockets capturing the workers' stdout

    # MPI transport
    rank2streams::Dict{Int,Tuple{IO,IO}} # map MPI ranks to (input,output) streams
    ranks_left::Array{Int,1}             # MPI ranks that do not have a stream pair yet

    # MPI_TRANSPORT_ALL
    comm::MPI.Comm
    initiate_shutdown::Channel{Nothing}
    sending_done::Channel{Nothing}
    receiving_done::Channel{Nothing}

    function MPIManager(; np::Integer = Sys.CPU_THREADS,
                          launch_timeout::Real = 60.0,
                          mode::TransportMode = MPI_ON_WORKERS,
                          master_tcp_interface::String = "")
        mgr = new()
        mgr.np = np
        mgr.mpi2j = Dict{Int,Int}()
        mgr.j2mpi = Dict{Int,Int}()
        mgr.mode = mode

        # Only start MPI processes for MPI_ON_WORKERS
        mgr.launched = mode != MPI_ON_WORKERS
        @assert MPI.Initialized() == mgr.launched
        mgr.launch_timeout = launch_timeout
        mgr.initialized = false
        mgr.cond_initialized = Condition()
        if np == 0
            # Special case: no workers
            mgr.initialized = true
            if mgr.mode != MPI_ON_WORKERS
                # Set up the mapping for the manager
                mgr.j2mpi[1] = 0
                mgr.mpi2j[0] = 1
            end
        end

        # Listen on TCP sockets if necessary
        if mode != MPI_TRANSPORT_ALL
            # Start a listener for capturing stdout from the workers
            if master_tcp_interface != ""
                # Listen on the specified interface. This allows direct
                # connections from other hosts on the same network as that
                # interface.
                port, server = listenany(getaddrinfo(master_tcp_interface), 11000)
            else
                # Listen on the default interface (localhost).
                # This precludes direct connections from other hosts.
                port, server = listenany(11000)
            end
            ip = getsockname(server)[1].host
            # Initialize the stream array before the accept loop can push to it
            mgr.stdout_ios = IO[]
            @async begin
                while true
                    sock = accept(server)
                    push!(mgr.stdout_ios, sock)
                end
            end
            mgr.port = port
            mgr.ip = ip
        else
            mgr.rank2streams = Dict{Int,Tuple{IO,IO}}()
            size = MPI.Comm_size(MPI.COMM_WORLD)
            mgr.ranks_left = collect(1:size-1)
        end

        if mode == MPI_TRANSPORT_ALL
            mgr.sending_done = Channel{Nothing}(np)
            mgr.receiving_done = Channel{Nothing}(1)
        end
        mgr.initiate_shutdown = Channel{Nothing}(1)
        global initiate_shutdown = mgr.initiate_shutdown

        return mgr
    end
end
function Base.show(io::IO, mgr::MPIManager)
    print(io, "MPI.MPIManager(np=$(mgr.np),launched=$(mgr.launched),mode=$(mgr.mode))")
end
################################################################################
# Cluster manager functionality required by Base, mostly targeting the
# MPI_ON_WORKERS case

# Launch a new worker, called from Base.addprocs
function Distributed.launch(mgr::MPIManager, params::Dict,
                            instances::Array, cond::Condition)
    try
        if mgr.mode == MPI_ON_WORKERS
            # Start the workers
            if mgr.launched
                println("Reuse of an MPIManager is not allowed.")
                println("Try again with a different instance of MPIManager.")
                throw(ErrorException("Reuse of MPIManager is not allowed."))
            end
            cookie = string(":cookie_", Distributed.cluster_cookie())
            setup_cmds = `import MPIClusterManagers\;MPIClusterManagers.setup_worker'('$(mgr.ip),$(mgr.port),$cookie')'`
            MPI.mpiexec() do mpiexec_cmd
                mpi_cmd = `$mpiexec_cmd -n $(mgr.np) $(params[:exename]) -e $(Base.shell_escape(setup_cmds))`
                open(detach(mpi_cmd))
            end
            mgr.launched = true
        end

        if mgr.mode != MPI_TRANSPORT_ALL
            # Wait for the workers to connect back to the manager
            t0 = time()
            while (length(mgr.stdout_ios) < mgr.np &&
                   time() - t0 < mgr.launch_timeout)
                sleep(1.0)
            end
            if length(mgr.stdout_ios) != mgr.np
                error("Timeout -- the workers did not connect to the manager")
            end

            # Traverse all worker I/O streams and receive their MPI ranks
            configs = Array{WorkerConfig}(undef, mgr.np)
            @sync begin
                for io in mgr.stdout_ios
                    @async let io = io
                        config = WorkerConfig()
                        config.io = io
                        # Add the config to the correct slot so that MPI ranks
                        # and Julia pids are in the same order
                        rank = Serialization.deserialize(io)
                        idx = mgr.mode == MPI_ON_WORKERS ? rank + 1 : rank
                        configs[idx] = config
                    end
                end
            end

            # Append our configs and notify the caller
            append!(instances, configs)
            notify(cond)
        else
            # This is a pure MPI configuration -- we don't need any bookkeeping
            for cnt in 1:mgr.np
                push!(instances, WorkerConfig())
            end
            notify(cond)
        end
    catch e
        println("Error in MPI launch $e")
        rethrow(e)
    end
end
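
# For illustration only: with np=4 and the default julia executable, the
# command assembled above expands to roughly (exact quoting depends on
# Base.shell_escape):
#
#   mpiexec -n 4 julia -e 'import MPIClusterManagers;
#       MPIClusterManagers.setup_worker(<ip>,<port>,:cookie_<cluster cookie>)'
#
# Each of the 4 ranks then connects back to the manager's TCP listener.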
# Entry point for MPI worker processes for MPI_ON_WORKERS and TCP_TRANSPORT_ALL
setup_worker(host, port; kwargs...) = setup_worker(host, port, nothing; kwargs...)
function setup_worker(host, port, cookie; stdout_to_master=true, stderr_to_master=true)
    !MPI.Initialized() && MPI.Init()

    # Connect to the manager
    io = connect(IPv4(host), port)
    wait_connected(io)
    stdout_to_master && redirect_stdout(io)
    stderr_to_master && redirect_stderr(io)

    # Send our MPI rank to the manager
    rank = MPI.Comm_rank(MPI.COMM_WORLD)
    Serialization.serialize(io, rank)

    # Hand over control to Base
    if cookie === nothing
        Distributed.start_worker(io)
    else
        if isa(cookie, Symbol)
            cookie = string(cookie)[8:end] # strip the leading "cookie_"
        end
        Distributed.start_worker(io, cookie)
    end
end
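
# Sketch of the handshake performed above (for orientation, not executed):
#   1. each rank opens a TCP connection to (mgr.ip, mgr.port)
#   2. the accept loop in the MPIManager constructor collects the socket
#   3. the rank serializes its MPI rank over the socket, which launch() uses
#      to order the WorkerConfigs so Julia pids follow MPI rank order
#   4. Distributed.start_worker takes over the connection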
# Manage a worker (e.g. register / deregister it)
function Distributed.manage(mgr::MPIManager, id::Integer, config::WorkerConfig, op::Symbol)
    if op == :register
        # Retrieve the MPI rank from the worker
        # TODO: Why is this necessary? The workers already sent their rank.
        rank = remotecall_fetch(() -> MPI.Comm_rank(MPI.COMM_WORLD), id)
        mgr.j2mpi[id] = rank
        mgr.mpi2j[rank] = id
        if length(mgr.j2mpi) == mgr.np
            # All workers registered
            mgr.initialized = true
            notify(mgr.cond_initialized)
            if mgr.mode != MPI_ON_WORKERS
                # Set up the mapping for the manager
                mgr.j2mpi[1] = 0
                mgr.mpi2j[0] = 1
            end
        end
    elseif op == :deregister
        @info "pid=$(getpid()) id=$id op=$op"
        # TODO: Sometimes -- very rarely -- Julia calls this `deregister`
        # function and then outputs a warning such as """error in running
        # finalizer: ErrorException("no process with id 3 exists")""". These
        # warnings seem harmless; still, we should find out what is going
        # wrong here.
    elseif op == :interrupt
        # TODO: This should never happen if we rmprocs the workers properly
        @info "pid=$(getpid()) id=$id op=$op"
        @assert false
    elseif op == :finalize
        # This is called from within a finalizer after deregistering; do nothing
    else
        @info "pid=$(getpid()) id=$id op=$op"
        @assert false # Unsupported operation
    end
end
# Kill a worker
function kill(mgr::MPIManager, pid::Int, config::WorkerConfig)
    # Exit the worker cleanly to avoid EOF errors on the workers
    @spawnat pid begin
        MPI.Finalize()
        exit()
    end
    Distributed.set_worker_state(Distributed.Worker(pid), Distributed.W_TERMINATED)
end
# Set up a connection to a worker
function connect(mgr::MPIManager, pid::Int, config::WorkerConfig)
    if mgr.mode != MPI_TRANSPORT_ALL
        # Forward the call to the connect function in Base
        return invoke(connect, Tuple{ClusterManager, Int, WorkerConfig},
                      mgr, pid, config)
    end
    rank = MPI.Comm_rank(mgr.comm)
    if rank == 0
        # Choose a rank for this worker
        to_rank = pop!(mgr.ranks_left)
        config.connect_at = to_rank
        return start_send_event_loop(mgr, to_rank)
    else
        return start_send_event_loop(mgr, config.connect_at)
    end
end
# Event loop for sending data to one other process, for the MPI_TRANSPORT_ALL
# case
function start_send_event_loop(mgr::MPIManager, rank::Integer)
    try
        r_s = Base.BufferStream()
        w_s = Base.BufferStream()
        mgr.rank2streams[rank] = (r_s, w_s)
        # TODO: There is one task per communication partner -- this can be
        # quite expensive when there are many workers. Design something
        # better. For example, instead of maintaining two streams per worker,
        # provide only abstract functions to write to / read from these
        # streams.
        @async begin
            rr = MPI.Comm_rank(mgr.comm)
            reqs = MPI.Request[]
            while !isready(mgr.initiate_shutdown)
                # When data are available, send them
                while bytesavailable(w_s) > 0
                    data = take!(w_s.buffer)
                    push!(reqs, MPI.Isend(data, rank, 0, mgr.comm))
                end
                # Reap completed send requests to keep the request list bounded
                if !isempty(reqs)
                    (indices, stats) = MPI.Testsome!(reqs)
                    filter!(req -> req != MPI.REQUEST_NULL, reqs)
                end
                # TODO: Need a better way to integrate with libuv's event loop
                yield()
            end
            put!(mgr.sending_done, nothing)
        end
        (r_s, w_s)
    catch e
        Base.show_backtrace(stdout, catch_backtrace())
        println(e)
        rethrow(e)
    end
end
################################################################################
# Alternative startup model: all Julia processes are started via an external
# mpirun, and the user does not call addprocs.

# Enter the MPI cluster manager's main loop (does not return on the workers)
function start_main_loop(mode::TransportMode=TCP_TRANSPORT_ALL;
                         comm::MPI.Comm=MPI.COMM_WORLD,
                         stdout_to_master=true,
                         stderr_to_master=true)
    !MPI.Initialized() && MPI.Init()
    @assert MPI.Initialized() && !MPI.Finalized()
    if mode == TCP_TRANSPORT_ALL
        # Base is handling the workers and their event loop.
        # The workers have no manager object in which to store the
        # communicator. TODO: Use a global variable?
        comm = MPI.COMM_WORLD
        rank = MPI.Comm_rank(comm)
        size = MPI.Comm_size(comm)
        if rank == 0
            # On the manager: perform the usual steps
            # Create the manager object
            mgr = MPIManager(np=size-1, mode=mode)
            mgr.comm = comm
            # Needed because of Julia commit
            # https://github.com/JuliaLang/julia/commit/299300a409c35153a1fa235a05c3929726716600
            if isdefined(Distributed, :init_multi)
                Distributed.init_multi()
            end
            # Send the connection information to all workers
            # TODO: Use Bcast
            for j in 1:size-1
                cookie = Distributed.cluster_cookie()
                MPI.send((mgr.ip, mgr.port, cookie), j, 0, comm)
            end
            # Tell Base about the workers
            addprocs(mgr)
            return mgr
        else
            # On a worker: receive the connection information
            (obj, status) = MPI.recv(0, 0, comm)
            (host, port, cookie) = obj
            # Call the regular worker entry point (does not return)
            setup_worker(host, port, cookie,
                         stdout_to_master=stdout_to_master,
                         stderr_to_master=stderr_to_master)
        end
    elseif mode == MPI_TRANSPORT_ALL
        comm = MPI.Comm_dup(comm)
        rank = MPI.Comm_rank(comm)
        size = MPI.Comm_size(comm)
        # We are handling the workers and their event loops on our own
        if rank == 0
            # On the manager:
            # Create the manager object
            mgr = MPIManager(np=size-1, mode=mode)
            mgr.comm = comm
            # Send the cookie over. Introduced in v"0.5.0-dev+4047"; irrelevant
            # under MPI transport, but needed to satisfy the changed protocol.
            Distributed.init_multi()
            MPI.bcast(Distributed.cluster_cookie(), 0, comm)
            # Start the event loop for the workers
            @async receive_event_loop(mgr)
            # Tell Base about the workers
            addprocs(mgr)
            return mgr
        else
            # On a worker:
            # Create a "fake" manager object since Base wants one
            mgr = MPIManager(np=size-1, mode=mode)
            mgr.comm = comm
            # Receive the cookie
            cookie = MPI.bcast(nothing, 0, comm)
            Distributed.init_worker(cookie, mgr)
            # Start a worker event loop
            receive_event_loop(mgr)
            if isdefined(MPI, :free) && hasmethod(MPI.free, Tuple{MPI.Comm})
                MPI.free(comm)
            end
            MPI.Finalize()
            exit()
        end
    else
        error("Unknown mode $mode")
    end
end
# Event loop for receiving data, for the MPI_TRANSPORT_ALL case
function receive_event_loop(mgr::MPIManager)
    num_send_loops = 0
    while !isready(mgr.initiate_shutdown)
        (hasdata, stat) = MPI.Iprobe(MPI.MPI_ANY_SOURCE, 0, mgr.comm)
        if hasdata
            count = MPI.Get_count(stat, UInt8)
            buf = Array{UInt8}(undef, count)
            from_rank = MPI.Get_source(stat)
            MPI.Recv!(buf, from_rank, 0, mgr.comm)

            streams = get(mgr.rank2streams, from_rank, nothing)
            if streams === nothing
                # This is the first time we communicate with this rank.
                # Set up a new connection.
                (r_s, w_s) = start_send_event_loop(mgr, from_rank)
                Distributed.process_messages(r_s, w_s)
                num_send_loops += 1
            else
                (r_s, w_s) = streams
            end
            write(r_s, buf)
        else
            # TODO: Need a better way to integrate with libuv's event loop
            yield()
        end
    end
    for i in 1:num_send_loops
        fetch(mgr.sending_done)
    end
    put!(mgr.receiving_done, nothing)
end
# Stop the main loop
# This function should be called by the manager process only.
function stop_main_loop(mgr::MPIManager)
    if mgr.mode == TCP_TRANSPORT_ALL
        # Shut down all workers
        rmprocs(workers())
        # Poor man's flush of the send queue
        sleep(1)
        put!(mgr.initiate_shutdown, nothing)
        MPI.Finalize()
    elseif mgr.mode == MPI_TRANSPORT_ALL
        # Shut down all workers, but not ourselves yet
        for i in workers()
            if i != myid()
                @spawnat i begin
                    global initiate_shutdown
                    put!(initiate_shutdown, nothing)
                end
            end
        end
        # Poor man's flush of the send queue
        sleep(1)
        # Shut down ourselves
        put!(mgr.initiate_shutdown, nothing)
        wait(mgr.receiving_done)
        MPI.Finalize()
    else
        @assert false
    end
end
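
# Usage sketch for the mpirun-driven startup (mirrors the package README;
# run with e.g. `mpiexec -n 4 julia script.jl`):
#
#   using MPIClusterManagers, Distributed
#   # Workers block inside start_main_loop and never reach the code below it;
#   # only MPI rank 0 returns, acting as the manager.
#   mgr = MPIClusterManagers.start_main_loop(MPI_TRANSPORT_ALL)
#   # ... use the Distributed API as usual ...
#   MPIClusterManagers.stop_main_loop(mgr)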
################################################################################
# MPI-specific communication methods

# Execute a command on all MPI ranks
# This uses MPI as the communication method even if @everywhere uses TCP
function mpi_do(mgr::MPIManager, expr)
    !mgr.initialized && wait(mgr.cond_initialized)
    jpids = keys(mgr.j2mpi)
    refs = Array{Any}(undef, length(jpids))
    for (i, p) in enumerate(Iterators.filter(x -> x != myid(), jpids))
        refs[i] = remotecall(expr, p)
    end
    # Execution on the local process should come last, since it can block the
    # main event loop
    if myid() in jpids
        refs[end] = remotecall(expr, myid())
    end

    # Retrieve remote exceptions, if any
    @sync begin
        for r in refs
            @async begin
                resp = remotecall_fetch(r.where, r) do rr
                    wrkr_result = rr[]
                    # Only return the result if it is an exception, i.e. don't
                    # ship back a valid result of a worker computation. This is
                    # mpi_do, not mpi_callfetch.
                    isa(wrkr_result, Exception) ? wrkr_result : nothing
                end
                isa(resp, Exception) && throw(resp)
            end
        end
    end
    nothing
end
macro mpi_do(mgr, expr)
    quote
        # Evaluate the expression in the Main module
        thunk = () -> (Core.eval(Main, $(Expr(:quote, expr))); nothing)
        mpi_do($(esc(mgr)), thunk)
    end
end
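
# Example (sketch; assumes a launched manager `mgr` as above):
#
#   @mpi_do mgr begin
#       comm = MPI.COMM_WORLD
#       println("Hello world, I am $(MPI.Comm_rank(comm)) of $(MPI.Comm_size(comm))")
#   end
#
# Unlike @everywhere, this runs the block on every MPI rank known to `mgr`
# (including the manager under MPI_TRANSPORT_ALL and TCP_TRANSPORT_ALL).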
# All managed Julia processes
Distributed.procs(mgr::MPIManager) = sort(collect(keys(mgr.j2mpi)))
# All managed MPI ranks
mpiprocs(mgr::MPIManager) = sort(collect(keys(mgr.mpi2j)))
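
# For orientation (hypothetical values): with np=4 under MPI_ON_WORKERS the
# manager is pid 1 and has no MPI rank, so one would typically see
#   procs(mgr)    == [2, 3, 4, 5]   # Julia pids of the workers
#   mpiprocs(mgr) == [0, 1, 2, 3]   # their MPI ranks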