diff --git a/src/data.jl b/src/data.jl index 45c1895..c77499f 100644 --- a/src/data.jl +++ b/src/data.jl @@ -129,6 +129,7 @@ mutable struct NCData{T,N #=,TA=#} data_full::Array{T,4} missingmask::BitArray{3} meandata::Array{T,3} + mask::BitMatrix x::Array{T,5} isoutput::Vector{Bool} train::Bool @@ -211,7 +212,9 @@ export sizey dd = NCData(lon,lat,time,data_full,missingmask,ndims; train = false, obs_err_std = 1., - jitter_std = 0.05) + jitter_std = 0.05, + mask = trues(size(data_full)[1:2]), +) Return a structure holding the data for training (`train = true`) or testing (`train = false`) the neural network. `obs_err_std` is the error standard deviation of the @@ -231,6 +234,7 @@ function NCData(lon,lat,time,data_full,missingmask,ndims; cycle_periods = (365.25,), # days time_origin = DateTime(1970,1,1), remove_mean = true, + mask = trues(size(data_full)[1:2]), direction_obs = nothing, # auxdata = (), ) @@ -302,39 +306,47 @@ function NCData(lon,lat,time,data_full,missingmask,ndims; N = (is3D ? 4 : 3) - NCData{Float32,N}(Float32.(lon),Float32.(lat),time,data_full,missingmask,meandata[:,:,:,1],x, - isoutput, - train, - Float32.(obs_err_std), - Float32.(jitter_std), - lon_scaled, - lat_scaled, - time_cos, - time_sin, - ntime_win, -# auxdata, - direction_obs_, - output_ndims, - ndims, - ) + NCData{Float32,N}( + Float32.(lon),Float32.(lat), + time, + data_full, + missingmask, + meandata[:,:,:,1], + mask, + x, + isoutput, + train, + Float32.(obs_err_std), + Float32.(jitter_std), + lon_scaled, + lat_scaled, + time_cos, + time_sin, + ntime_win, + # auxdata, + direction_obs_, + output_ndims, + ndims, + ) end getp(x,sym,default) = (hasproperty(x, sym) ? getproperty(x,sym) : default) function NCData(data; kwargs...) - lon,lat,datatime,data_full,missingmask,mask = DINCAE.load_gridded_nc(data) + lon,lat,datatime,data_full,missingmask,mask = load_gridded_nc(data) default_jitter_std = 0.05 jitter_std = [getp(d,:jitter_std,default_jitter_std) for d in data] ndims = [getp(d,:ndims,1) for d in data] - return DINCAE.NCData(lon,lat,datatime,data_full,missingmask,ndims; - obs_err_std = [d.obs_err_std for d in data], - jitter_std = jitter_std, - isoutput = [d.isoutput for d in data], - kwargs...) + return NCData(lon,lat,datatime,data_full,missingmask,ndims; + obs_err_std = [d.obs_err_std for d in data], + jitter_std = jitter_std, + isoutput = [d.isoutput for d in data], + mask = mask, + kwargs...) end @@ -556,7 +568,10 @@ function getobs!(dd::NCData,data,index::Int) return data end -function savesample(ds,varnames,xrec,meandata,ii,offset; output_ndims = 1) +function savesample(ds,varnames,xrec,meandata,ii,offset; + output_ndims = 1, + mask = nothing) + fill_value = -9999. function accumulate!(var,index,slice,count) @@ -594,6 +609,7 @@ function savesample(ds,varnames,xrec,meandata,ii,offset; output_ndims = 1) end end + # typically the batch size nmax = size(xrec,4) if output_ndims == 1 @@ -610,6 +626,11 @@ function savesample(ds,varnames,xrec,meandata,ii,offset; output_ndims = 1) batch_sigma_rec[isnan.(recdata)] .= NaN for n in 1:nmax + if !isnothing(mask) + view(recdata,:,:,n)[.!mask] .= NaN + view(batch_sigma_rec,:,:,n)[.!mask] .= NaN + end + accumulate!(nc_batch_m_rec.var,n+offset,recdata[:,:,n],count) accumulate!(nc_batch_sigma_rec.var,n+offset,batch_sigma_rec[:,:,n],count) end diff --git a/src/model.jl b/src/model.jl index cee7690..26067a4 100644 --- a/src/model.jl +++ b/src/model.jl @@ -328,6 +328,10 @@ at the epochs defined by `save_epochs`. application. +Internally the time mean is removed (per default) from the data before it is reconstructed. +The time mean is also added back when the file is saved. +However, the mean is undefined for for are pixels in the data defined as valid (sea) by the mask which do not have any valid data in the training dataset. + See `DINCAE.load_gridded_nc` for more information about the netCDF file. """ function reconstruct(Atype,data_all,fnames_rec; @@ -533,12 +537,13 @@ function reconstruct(Atype,data_all,fnames_rec; offset = (ii-1)*batch_size - DINCAE.savesample( + savesample( ds_, output_varnames,xrec, train_data.meandata[:,:,findall(train_data.isoutput)], ii-1,offset, output_ndims = output_ndims, + mask = train_data.mask, ) end end