Skip to content

Commit

Permalink
mask output even if remove_mean = false (issue #24)
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander-Barth committed Apr 18, 2024
1 parent bdc9c94 commit d2ea559
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 24 deletions.
67 changes: 44 additions & 23 deletions src/data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ mutable struct NCData{T,N #=,TA=#}
data_full::Array{T,4}
missingmask::BitArray{3}
meandata::Array{T,3}
mask::BitMatrix
x::Array{T,5}
isoutput::Vector{Bool}
train::Bool
Expand Down Expand Up @@ -211,7 +212,9 @@ export sizey
dd = NCData(lon,lat,time,data_full,missingmask,ndims;
train = false,
obs_err_std = 1.,
jitter_std = 0.05)
jitter_std = 0.05,
mask = trues(size(data_full)[1:2]),
)
Return a structure holding the data for training (`train = true`) or testing (`train = false`)
the neural network. `obs_err_std` is the error standard deviation of the
Expand All @@ -231,6 +234,7 @@ function NCData(lon,lat,time,data_full,missingmask,ndims;
cycle_periods = (365.25,), # days
time_origin = DateTime(1970,1,1),
remove_mean = true,
mask = trues(size(data_full)[1:2]),
direction_obs = nothing,
# auxdata = (),
)
Expand Down Expand Up @@ -302,39 +306,47 @@ function NCData(lon,lat,time,data_full,missingmask,ndims;

N = (is3D ? 4 : 3)

NCData{Float32,N}(Float32.(lon),Float32.(lat),time,data_full,missingmask,meandata[:,:,:,1],x,
isoutput,
train,
Float32.(obs_err_std),
Float32.(jitter_std),
lon_scaled,
lat_scaled,
time_cos,
time_sin,
ntime_win,
# auxdata,
direction_obs_,
output_ndims,
ndims,
)
NCData{Float32,N}(
Float32.(lon),Float32.(lat),
time,
data_full,
missingmask,
meandata[:,:,:,1],
mask,
x,
isoutput,
train,
Float32.(obs_err_std),
Float32.(jitter_std),
lon_scaled,
lat_scaled,
time_cos,
time_sin,
ntime_win,
# auxdata,
direction_obs_,
output_ndims,
ndims,
)
end


"""
    getp(x, sym, default)

Return the property named `sym` of `x` if `x` has such a property,
otherwise return `default`.
"""
function getp(x, sym, default)
    if hasproperty(x, sym)
        return getproperty(x, sym)
    end
    return default
end

"""
    NCData(data; kwargs...)

Build an `NCData` from `data`, a collection of per-variable descriptions.

The coordinates, the data array, the missing-value mask and the `mask` are
obtained from `load_gridded_nc(data)`. Each element `d` of `data` must provide
`d.obs_err_std` and `d.isoutput`; `jitter_std` (default `0.05`) and `ndims`
(default `1`) are optional and looked up with `getp`.

All remaining keyword arguments are forwarded to the positional `NCData`
constructor. Because `kwargs...` is splatted last, a keyword supplied by the
caller (e.g. `mask`) takes precedence over the value computed here.
"""
function NCData(data; kwargs...)
    # Load the gridded fields exactly once; the returned `mask` is forwarded
    # below so that it is stored in the resulting structure.
    lon, lat, datatime, data_full, missingmask, mask = load_gridded_nc(data)

    default_jitter_std = 0.05

    # Optional per-variable settings fall back to their defaults.
    jitter_std = [getp(d, :jitter_std, default_jitter_std) for d in data]
    ndims = [getp(d, :ndims, 1) for d in data]

    return NCData(
        lon, lat, datatime, data_full, missingmask, ndims;
        obs_err_std = [d.obs_err_std for d in data],
        jitter_std = jitter_std,
        isoutput = [d.isoutput for d in data],
        mask = mask,
        kwargs...
    )
end

Expand Down Expand Up @@ -556,7 +568,10 @@ function getobs!(dd::NCData,data,index::Int)
return data
end

function savesample(ds,varnames,xrec,meandata,ii,offset; output_ndims = 1)
function savesample(ds,varnames,xrec,meandata,ii,offset;
output_ndims = 1,
mask = nothing)

fill_value = -9999.

function accumulate!(var,index,slice,count)
Expand Down Expand Up @@ -594,6 +609,7 @@ function savesample(ds,varnames,xrec,meandata,ii,offset; output_ndims = 1)
end
end

# typically the batch size
nmax = size(xrec,4)

if output_ndims == 1
Expand All @@ -610,6 +626,11 @@ function savesample(ds,varnames,xrec,meandata,ii,offset; output_ndims = 1)
batch_sigma_rec[isnan.(recdata)] .= NaN

for n in 1:nmax
if !isnothing(mask)
view(recdata,:,:,n)[.!mask] .= NaN
view(batch_sigma_rec,:,:,n)[.!mask] .= NaN
end

accumulate!(nc_batch_m_rec.var,n+offset,recdata[:,:,n],count)
accumulate!(nc_batch_sigma_rec.var,n+offset,batch_sigma_rec[:,:,n],count)
end
Expand Down
7 changes: 6 additions & 1 deletion src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,10 @@ at the epochs defined by `save_epochs`.
application.
Internally the time mean is removed (per default) from the data before it is reconstructed.
The time mean is also added back when the file is saved.
However, the mean is undefined for pixels that the mask marks as valid (sea) but that do not have any valid data in the training dataset.
See `DINCAE.load_gridded_nc` for more information about the netCDF file.
"""
function reconstruct(Atype,data_all,fnames_rec;
Expand Down Expand Up @@ -533,12 +537,13 @@ function reconstruct(Atype,data_all,fnames_rec;

offset = (ii-1)*batch_size

DINCAE.savesample(
savesample(
ds_,
output_varnames,xrec,
train_data.meandata[:,:,findall(train_data.isoutput)],
ii-1,offset,
output_ndims = output_ndims,
mask = train_data.mask,
)
end
end
Expand Down

0 comments on commit d2ea559

Please sign in to comment.