-
Notifications
You must be signed in to change notification settings - Fork 65
/
mpi.jl
73 lines (58 loc) · 1.73 KB
/
mpi.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# EXCLUDE FROM TESTING
using KernelAbstractions
using MPI
# TODO: Implement in MPI.jl
function cooperative_test!(req)
    # Cooperatively poll `req` until MPI reports completion: after each
    # Test we yield() so other Julia tasks (the matching send/recv side)
    # can make progress instead of busy-blocking the scheduler.
    while true
        completed = first(MPI.Test(req, MPI.Status))
        yield()
        completed && break
    end
end
function cooperative_wait(task::Task)
    # Spin until `task` is done, nudging MPI each iteration: the Iprobe
    # call drives the MPI progress engine while we wait, and yield()
    # hands control back so `task` itself can run.
    # NOTE(review): `MPI.MPI_ANY_SOURCE`/`MPI.MPI_ANY_TAG` and the
    # positional Iprobe signature are older MPI.jl API names — confirm
    # against the MPI.jl version this example is pinned to.
    while true
        Base.istaskdone(task) && break
        MPI.Iprobe(MPI.MPI_ANY_SOURCE, MPI.MPI_ANY_TAG, MPI.COMM_WORLD)
        yield()
    end
    # Re-wait to propagate any exception the task may have thrown.
    wait(task)
end
"""
    exchange!(h_send_buf, d_recv_buf, h_recv_buf, src_rank, dst_rank, comm)

Start a non-blocking ring exchange: post a receive of `h_recv_buf` from
`src_rank` and a send of `h_send_buf` to `dst_rank`, each completed by a
spawned task. The receive task additionally stages the received host
buffer into the device buffer `d_recv_buf`.

Returns the `(recv, send)` task pair; the caller must wait on both
(e.g. via `cooperative_wait`).

NOTE(review): relies on a global `backend` being in scope rather than
taking it as an argument — consider passing it explicitly.
"""
function exchange!(h_send_buf, d_recv_buf, h_recv_buf, src_rank, dst_rank, comm)
    # 666 is an arbitrary message tag; send and receive sides must match.
    recv_req = MPI.Irecv!(h_recv_buf, src_rank, 666, comm)
    recv = Base.Threads.@spawn begin
        KernelAbstractions.priority!(backend, :high)
        cooperative_test!(recv_req)
        # Host buffer is now filled; copy it onto the device.
        KernelAbstractions.copyto!(backend, d_recv_buf, h_recv_buf)
        KernelAbstractions.synchronize(backend) # Guaranteed to be cooperative
    end
    send = Base.Threads.@spawn begin
        # Fix: MPI.jl's non-blocking send is `MPI.Isend` (no `!`) — the
        # send buffer is not mutated, so there is no `Isend!` method.
        send_req = MPI.Isend(h_send_buf, dst_rank, 666, comm)
        cooperative_test!(send_req)
    end
    return recv, send
end
"""
    main(backend)

Ring-exchange smoke test: each rank sends a buffer holding its own rank
id to the next rank and receives from the previous rank, overlapping the
MPI transfer with a host-to-device copy on `backend`. Checks that the
device buffer ends up holding the source rank's id.
"""
function main(backend)
    MPI.Initialized() || MPI.Init()
    comm = MPI.COMM_WORLD
    MPI.Barrier(comm)

    # Neighbours on the ring (wrap around via mod).
    me = MPI.Comm_rank(comm)
    nranks = MPI.Comm_size(comm)
    dst_rank = mod(me + 1, nranks)
    src_rank = mod(me - 1, nranks)

    T = Int64
    M = 10
    # Device receive buffer starts at -1 so a failed exchange is visible.
    d_recv_buf = allocate(backend, T, M)
    fill!(d_recv_buf, -1)
    # Host staging buffers: send our rank id, receive into a -1 sentinel.
    h_send_buf = fill(T(me), M)
    h_recv_buf = fill(T(-1), M)
    KernelAbstractions.synchronize(backend)

    recv_task, send_task = exchange!(
        h_send_buf, d_recv_buf, h_recv_buf,
        src_rank, dst_rank, comm,
    )
    cooperative_wait(recv_task)
    cooperative_wait(send_task)

    @test all(d_recv_buf .== src_rank)
end
main(backend)