-
Notifications
You must be signed in to change notification settings - Fork 0
/
pm_utils.py
65 lines (55 loc) · 2.25 KB
/
pm_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from __future__ import print_function, division
from mpi4py import MPI
import numpy as np
from pmesh.pm import RealField, ComplexField
def ltoc(field, index):
    """
    Convert one local index of a RealField into its collective index.

    This inverts pmesh.pm.Field._ctol for a single index.

    Parameters
    ----------
    field : RealField
        Field whose local slab begins at ``field.start``.
    index : sequence of int
        A single local index with exactly ``field.ndim`` components.

    Returns
    -------
    tuple
        The collective index ``index + field.start``. Addition is
        element-wise on the assumption that ``field.start`` is a NumPy
        array (as pmesh provides) -- TODO confirm for other Field types.
    """
    assert isinstance(field, RealField)
    index_shape = np.array(index).shape
    assert index_shape == (field.ndim,)
    offset = field.start
    return tuple(index + offset)
def ltoc_index_arr(field, lindex_arr):
    """
    Convert an array of local indices into collective indices.

    Vectorized counterpart of ``ltoc``; inverts pmesh.pm.Field._ctol.

    Parameters
    ----------
    field : RealField
        Field whose local slab begins at ``field.start``.
    lindex_arr : np.ndarray of int
        Non-negative local indices. The last axis labels the field
        dimension; e.g. to convert N indices of a 3-D field, pass
        ``lindex_arr.shape == (N, 3)``. ndarray subclasses are accepted.

    Returns
    -------
    np.ndarray
        ``lindex_arr + field.start``, with the same shape as ``lindex_arr``.
    """
    assert isinstance(field, RealField)
    # isinstance (not an exact type() comparison) so ndarray subclasses pass.
    assert isinstance(lindex_arr, np.ndarray)
    assert np.all(lindex_arr >= 0)
    assert lindex_arr.shape[-1] == field.ndim
    # field.start (presumably a length-ndim array) broadcasts along the last
    # axis, offsetting each component by the start of this rank's slab.
    cindex_arr = lindex_arr + field.start
    return cindex_arr
def cgetitem_index_arr(field, cindex_arr):
    """
    Get values of a distributed RealField at an array of collective indices.

    This is a vector version of pmesh.pm.Field.cgetitem. Each rank fills in
    the entries whose collective index lies in its local range
    ``[field.start, field.start + field.shape)`` and leaves all others at 0.
    A sum-allreduce over ``field.pm.comm`` then combines the per-rank
    contributions, so every rank receives the combined result.

    This is a collective call: every rank of ``field.pm.comm`` must call it,
    with ``cindex_arr`` of the same shape, because the allreduce sums the
    per-rank arrays element-wise.

    Parameters
    ----------
    field : RealField
        The distributed field to read from (any number of dimensions).
    cindex_arr : np.ndarray of int
        Collective indices, ranging from 0 to Ngrid-1 in each dimension.
        The last axis labels the field dimension, i.e.
        ``cindex_arr.shape == (..., field.ndim)``. Any number of leading
        axes is supported, including none (a single index).

    Returns
    -------
    np.ndarray
        Values with shape ``cindex_arr.shape[:-1]`` and dtype
        ``field.value.dtype``. An index not inside any rank's local range
        (e.g. one that is out of bounds) yields 0.
    """
    assert isinstance(field, RealField)
    # isinstance (not an exact type() comparison) so ndarray subclasses pass.
    assert isinstance(cindex_arr, np.ndarray)
    assert np.all(cindex_arr >= 0)
    assert cindex_arr.shape[-1] == field.ndim

    value_arr = np.zeros(cindex_arr.shape[:-1], dtype=field.value.dtype)

    # Boolean mask over the leading axes: True where this rank owns the index.
    # Indexing with the mask itself, rather than np.where(...)[0], keeps this
    # correct for any number of leading axes, including a single index.
    owned = (np.all(cindex_arr >= field.start, axis=-1) &
             np.all(cindex_arr < (field.start + field.shape), axis=-1))

    # (K, ndim) local indices for the K entries owned by this rank, taken in
    # the same C order that value_arr[owned] assigns in.
    lindex_wanted = cindex_arr[owned] - field.start
    # One index array per dimension -> fancy indexing into the local field.
    value_arr[owned] = field[tuple(lindex_wanted.T)]

    # Sum contributions from all ranks. This assumes rank-local ranges do not
    # overlap, so each in-range index is filled by exactly one rank.
    # NOTE(review): lowercase allreduce pickles the array; Comm.Allreduce
    # with MPI.IN_PLACE would avoid that for large inputs.
    value_arr = field.pm.comm.allreduce(value_arr, op=MPI.SUM)
    return value_arr