"""multimodel statistics.
Functions for multi-model operations
supports a multitude of multimodel statistics
computations; the only requisite is the ingested
cubes have (TIME-LAT-LON) or (TIME-PLEV-LAT-LON)
dimensions; and obviously consistent units.
It operates on different (time) spans:
- full: computes stats on full dataset time;
- overlap: computes common time overlap between datasets;
"""
import logging
import re
from datetime import datetime
from functools import partial, reduce

import cf_units
import iris
import numpy as np
import scipy.stats

from ._time import regrid_time

logger = logging.getLogger(__name__)


def _get_time_offset(time_unit):
    """Return a datetime object for the reference date of time_unit."""
    # time_unit is e.g. 'days since 1950-01-01 00:00:00.0000000 UTC'
    cfunit = cf_units.Unit(time_unit, calendar=cf_units.CALENDAR_STANDARD)
    time_offset = cfunit.num2date(0)
    return time_offset


def _plev_fix(dataset, pl_idx):
    """Extract valid plev data.

    This function takes care of situations in which certain plevs are
    completely masked due to unavailable interpolation boundaries.
    """
if np.ma.is_masked(dataset):
# keep only the valid plevs
if not np.all(dataset.mask[pl_idx]):
statj = np.ma.array(dataset[pl_idx], mask=dataset.mask[pl_idx])
else:
logger.debug('All vals in plev are masked, ignoring.')
statj = None
else:
mask = np.zeros_like(dataset[pl_idx], bool)
statj = np.ma.array(dataset[pl_idx], mask=mask)
return statj


def _quantile(data, axis, quantile):
    """Calculate quantile.

    Workaround for calling scipy's mquantiles with arrays of more than two
    dimensions. Similar to iris' _percentiles function; see the discussion
    at https://github.com/SciTools/iris/pull/625
    """
# Ensure that the target axis is the last dimension.
data = np.rollaxis(data, axis, start=data.ndim)
shape = data.shape[:-1]
# Flatten any leading dimensions.
if shape:
data = data.reshape([np.prod(shape), data.shape[-1]])
# Perform the quantile calculation.
result = scipy.stats.mstats.mquantiles(data,
quantile,
axis=-1,
alphap=1,
betap=1)
    # Unflatten any leading dimensions.
if shape:
result = result.reshape(shape)
# Check whether to reduce to a scalar result
if result.shape == (1, ):
result = result[0]
return result


def _compute_statistic(data, statistic_name):
    """Compute multimodel statistic."""
data = np.ma.array(data)
statistic = data[0]
if statistic_name == 'median':
statistic_function = np.ma.median
elif statistic_name == 'mean':
statistic_function = np.ma.mean
elif statistic_name == 'std':
statistic_function = np.ma.std
elif statistic_name == 'max':
statistic_function = np.ma.max
elif statistic_name == 'min':
statistic_function = np.ma.min
elif re.match(r"^(p\d{1,2})(\.\d*)?$", statistic_name):
# percentiles between p0 and p99.99999...
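        # e.g. 'p50' -> quantile 0.5, 'p95' -> 0.95, 'p66.6' -> 0.666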
quantile = float(statistic_name[1:]) / 100
statistic_function = partial(_quantile, quantile=quantile)
else:
raise ValueError(f'No such statistic: `{statistic_name}`')
# no plevs
if len(data[0].shape) < 3:
# get all NOT fully masked data - u_data
# data is per time point
# so we can safely NOT compute stats for single points
if data.ndim == 1:
u_datas = [d for d in data]
else:
u_datas = [d for d in data if not np.all(d.mask)]
if len(u_datas) > 1:
statistic = statistic_function(data, axis=0)
else:
statistic.mask = True
return statistic
# plevs
for j in range(statistic.shape[0]):
plev_check = []
for cdata in data:
fixed_data = _plev_fix(cdata, j)
if fixed_data is not None:
plev_check.append(fixed_data)
# check for nr datasets
if len(plev_check) > 1:
plev_check = np.ma.array(plev_check)
statistic[j] = statistic_function(plev_check, axis=0)
else:
statistic.mask[j] = True
return statistic


def _put_in_cube(template_cube, cube_data, statistic, t_axis):
    """Quick cube building and saving."""
if t_axis is None:
times = template_cube.coord('time')
else:
unit_name = template_cube.coord('time').units.name
tunits = cf_units.Unit(unit_name, calendar="standard")
times = iris.coords.DimCoord(t_axis,
standard_name='time',
units=tunits,
var_name='time')
coord_names = [c.long_name for c in template_cube.coords()]
coord_names.extend([c.standard_name for c in template_cube.coords()])
if 'latitude' in coord_names:
lats = template_cube.coord('latitude')
else:
lats = None
if 'longitude' in coord_names:
lons = template_cube.coord('longitude')
else:
lons = None
# no plevs
if len(template_cube.shape) == 3:
cspec = [(times, 0), (lats, 1), (lons, 2)]
# plevs
elif len(template_cube.shape) == 4:
plev = template_cube.coord('air_pressure')
cspec = [(times, 0), (plev, 1), (lats, 2), (lons, 3)]
elif len(template_cube.shape) == 1:
cspec = [
(times, 0),
]
elif len(template_cube.shape) == 2:
# If you're going to hardwire air_pressure into this,
# might as well have depth here too.
plev = template_cube.coord('depth')
cspec = [
(times, 0),
(plev, 1),
]
    # mask any invalid (NaN/inf) values before building the cube
fixed_dspec = np.ma.fix_invalid(cube_data, copy=False, fill_value=1e+20)
# put in cube
stats_cube = iris.cube.Cube(fixed_dspec,
dim_coords_and_dims=cspec,
long_name=statistic)
coord_names = [coord.name() for coord in template_cube.coords()]
if 'air_pressure' in coord_names:
if len(template_cube.shape) == 3:
stats_cube.add_aux_coord(template_cube.coord('air_pressure'))
stats_cube.var_name = template_cube.var_name
stats_cube.long_name = template_cube.long_name
stats_cube.standard_name = template_cube.standard_name
stats_cube.units = template_cube.units
return stats_cube


def _datetime_to_int_days(cube):
    """Return a list of integer days converted from cube datetime cells."""
cube = _align_yearly_axes(cube)
time_cells = [cell.point for cell in cube.coord('time').cells()]
# extract date info
real_dates = []
for date_obj in time_cells:
# real_date resets the actual data point day
# to the 1st of the month so that there are no
# wrong overlap indices
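        # e.g. with time units 'days since 1950-01-01', both 1950-01-15 and
        # 1950-01-20 map to 1950-01-01, i.e. integer day 0 (illustrative
        # values, not from the original code)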
real_date = datetime(date_obj.year, date_obj.month, 1, 0, 0, 0)
real_dates.append(real_date)
# get the number of days starting from the reference unit
time_unit = cube.coord('time').units.name
time_offset = _get_time_offset(time_unit)
days = [(date_obj - time_offset).days for date_obj in real_dates]
return days


def _align_yearly_axes(cube):
    """Perform a time-regridding operation to align time axes for yr data."""
years = [cell.point.year for cell in cube.coord('time').cells()]
# be extra sure that the first point is not in the previous year
if 0 not in np.diff(years):
return regrid_time(cube, 'yr')
return cube


def _get_overlap(cubes):
    """Get discrete time overlaps.

    This method gets the bounds of the time coordinate from each cube and
    assembles a continuous time axis with smallest unit 1; it then finds
    the overlap via a 1-dim intersection, taking the floor of the first
    date and the ceiling of the last date.
    """
all_times = []
for cube in cubes:
span = _datetime_to_int_days(cube)
start, stop = span[0], span[-1]
all_times.append([start, stop])
bounds = [range(b[0], b[-1] + 1) for b in all_times]
time_pts = reduce(np.intersect1d, bounds)
if len(time_pts) > 1:
time_bounds_list = [time_pts[0], time_pts[-1]]
return time_bounds_list


def _slice_cube(cube, t_1, t_2):
    """Efficient slicer.

    Simple cube data slicer on indices of common time-data elements.
    """
time_pts = [t for t in cube.coord('time').points]
converted_t = _datetime_to_int_days(cube)
idxs = sorted([
time_pts.index(ii) for ii, jj in zip(time_pts, converted_t)
if t_1 <= jj <= t_2
])
return [idxs[0], idxs[-1]]


def _monthly_t(cubes):
    """Rearrange time points for monthly data."""
# get original cubes tpoints
days = {day for cube in cubes for day in _datetime_to_int_days(cube)}
return sorted(days)


def _full_time_slice(cubes, ndat, indices, ndatarr, t_idx):
    """Construct a contiguous collection over time."""
for idx_cube, cube in enumerate(cubes):
# reset mask
ndat.mask = True
ndat[indices[idx_cube]] = cube.data
if np.ma.is_masked(cube.data):
ndat.mask[indices[idx_cube]] = cube.data.mask
else:
ndat.mask[indices[idx_cube]] = False
ndatarr[idx_cube] = ndat[t_idx]
# return time slice
return ndatarr


def _assemble_overlap_data(cubes, interval, statistic):
    """Get statistical data in iris cubes for OVERLAP."""
start, stop = interval
sl_1, sl_2 = _slice_cube(cubes[0], start, stop)
stats_dats = np.ma.zeros(cubes[0].data[sl_1:sl_2 + 1].shape)
# keep this outside the following loop
# this speeds up the code by a factor of 15
indices = [_slice_cube(cube, start, stop) for cube in cubes]
for i in range(stats_dats.shape[0]):
time_data = [
cube.data[indx[0]:indx[1] + 1][i]
for cube, indx in zip(cubes, indices)
]
stats_dats[i] = _compute_statistic(time_data, statistic)
stats_cube = _put_in_cube(cubes[0][sl_1:sl_2 + 1],
stats_dats,
statistic,
t_axis=None)
return stats_cube


def _assemble_full_data(cubes, statistic):
    """Get statistical data in iris cubes for FULL."""
# all times, new MONTHLY data time axis
time_axis = [float(fl) for fl in _monthly_t(cubes)]
# new big time-slice array shape
new_shape = [len(time_axis)] + list(cubes[0].shape[1:])
# assemble an array to hold all time data
# for all cubes; shape is (ncubes,(plev), lat, lon)
new_arr = np.ma.empty([len(cubes)] + list(new_shape[1:]))
# data array for stats computation
stats_dats = np.ma.zeros(new_shape)
# assemble indices list to chop new_arr on
indices_list = []
# empty data array to hold time slices
empty_arr = np.ma.empty(new_shape)
# loop through cubes and populate empty_arr with points
for cube in cubes:
time_redone = _datetime_to_int_days(cube)
oidx = [time_axis.index(s) for s in time_redone]
indices_list.append(oidx)
for i in range(new_shape[0]):
# hold time slices only
new_datas_array = _full_time_slice(cubes, empty_arr, indices_list,
new_arr, i)
# list to hold time slices
time_data = []
for j in range(len(cubes)):
time_data.append(new_datas_array[j])
stats_dats[i] = _compute_statistic(time_data, statistic)
stats_cube = _put_in_cube(cubes[0], stats_dats, statistic, time_axis)
return stats_cube


def multi_model_statistics(products, span, statistics, output_products=None):
    """Compute multi-model statistics.

    Multimodel statistics are computed along the time axis. They can be
    computed across a common overlap in time (set span: overlap) or across
    the full length in time of each model (set span: full). Restrictive
    computation is also available by excluding any set of models that the
    user does not want to include in the statistics (set exclude:
    [excluded models list]).

    Restrictions on the input data:
    - model datasets must have consistent shapes,
    - higher dimensional data is not supported (i.e. dims higher than four:
      time, vertical axis, two horizontal axes).

    Parameters
    ----------
    products: list
        list of data products or cubes to be used in the multimodel
        statistics computation; the cube attribute of a product is the data
        cube for computing the statistics.
    span: str
        overlap or full; if overlap, statistics are computed on the common
        time span; if full, statistics are computed on the full time spans,
        ignoring missing data.
    statistics: list of str
        list of statistical measure(s) to be computed. Available options:
        mean, median, max, min, std, or pXX.YY (for percentile XX.YY;
        decimal part optional).
    output_products: dict
        dictionary of output products. MUST be specified if products are NOT
        cubes.

    Returns
    -------
    set or dict or list
        `set` of data products if `output_products` is given
        `dict` of cubes if `output_products` is not given
        `list` of input cubes if there is no overlap between cubes when
        using `span='overlap'`

    Raises
    ------
    ValueError
        If span is neither overlap nor full.
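
    Examples
    --------
    A minimal sketch (assuming ``cube1`` and ``cube2`` are two iris cubes
    with consistent shapes and units; both names are hypothetical)::

        stats = multi_model_statistics(
            [cube1, cube2], span='overlap', statistics=['mean', 'p95'])
        mean_cube = stats['mean']
        p95_cube = stats['p95']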
"""
logger.debug('Multimodel statistics: computing: %s', statistics)
if len(products) < 2:
logger.info("Single dataset in list: will not compute statistics.")
return products
if output_products:
cubes = [cube for product in products for cube in product.cubes]
statistic_products = set()
else:
cubes = products
statistic_products = {}
if span == 'overlap':
# check if we have any time overlap
interval = _get_overlap(cubes)
if interval is None:
logger.info("Time overlap between cubes is none or a single point."
"check datasets: will not compute statistics.")
return products
logger.debug("Using common time overlap between "
"datasets to compute statistics.")
elif span == 'full':
logger.debug("Using full time spans to compute statistics.")
else:
raise ValueError(
"Unexpected value for span {}, choose from 'overlap', 'full'".
format(span))
for statistic in statistics:
# Compute statistic
if span == 'overlap':
statistic_cube = _assemble_overlap_data(cubes, interval, statistic)
elif span == 'full':
statistic_cube = _assemble_full_data(cubes, statistic)
statistic_cube.data = np.ma.array(statistic_cube.data,
dtype=np.dtype('float32'))
if output_products:
# Add to output product and log provenance
statistic_product = output_products[statistic]
statistic_product.cubes = [statistic_cube]
for product in products:
statistic_product.wasderivedfrom(product)
logger.info("Generated %s", statistic_product)
statistic_products.add(statistic_product)
else:
statistic_products[statistic] = statistic_cube
if output_products:
products |= statistic_products
return products
return statistic_products