Skip to content

Commit

Permalink
Merge pull request #246 from vincelhx/branch_dev_vinc
Browse files Browse the repository at this point in the history
homogeneous code for the denoising of the three sensors
  • Loading branch information
agrouaze authored Nov 26, 2024
2 parents 168a3b8 + 330a0da commit 41b613d
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 52 deletions.
53 changes: 38 additions & 15 deletions src/xsar/radarsat2_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
logger.addHandler(logging.NullHandler())

# we know tiff as no geotransform : ignore warning
warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning)
warnings.filterwarnings(
"ignore", category=rasterio.errors.NotGeoreferencedWarning)

# allow nan without warnings
# some dask warnings are still non filtered: https://github.com/dask/dask/issues/3245
Expand Down Expand Up @@ -92,7 +93,8 @@ def __init__(
# assert isinstance(sar_meta.coords2ll(100, 100),tuple)
else:
# we want self.sar_meta to be a dask actor on a worker
self.sar_meta = BlockingActorProxy(RadarSat2Meta.from_dict, dataset_id.dict)
self.sar_meta = BlockingActorProxy(
RadarSat2Meta.from_dict, dataset_id.dict)
del dataset_id
if self.sar_meta.multidataset:
raise IndexError(
Expand Down Expand Up @@ -162,7 +164,8 @@ def __init__(
"gamma0": "lutGamma",
"beta0": "lutBeta",
}
geoloc_vars = ["latitude", "longitude", "altitude", "incidence", "elevation"]
geoloc_vars = ["latitude", "longitude",
"altitude", "incidence", "elevation"]
for vv in skip_variables:
if vv in geoloc_vars:
geoloc_vars.remove(vv)
Expand Down Expand Up @@ -221,7 +224,8 @@ def __init__(
dask.array.empty_like(
self._dataset.digital_number.isel(pol=0).drop("pol"),
dtype=np.int8,
name="empty_var_tmpl-%s" % dask.base.tokenize(self.sar_meta.name),
name="empty_var_tmpl-%s" % dask.base.tokenize(
self.sar_meta.name),
),
dims=("line", "sample"),
coords={
Expand Down Expand Up @@ -270,9 +274,11 @@ def __init__(
self._dataset = xr.merge([self._dataset, rasters])
self._dataset = xr.merge([self.interpolate_times, self._dataset])
if "ground_heading" not in skip_variables:
self._dataset = xr.merge([self.load_ground_heading(), self._dataset])
self._dataset = xr.merge(
[self.load_ground_heading(), self._dataset])
if "velocity" not in skip_variables:
self._dataset = xr.merge([self.get_sensor_velocity(), self._dataset])
self._dataset = xr.merge(
[self.get_sensor_velocity(), self._dataset])
self._rasterized_masks = self.load_rasterized_masks()
self._dataset = xr.merge([self._rasterized_masks, self._dataset])
"""a = self._dataset.copy()
Expand Down Expand Up @@ -399,7 +405,8 @@ def load_from_geoloc(self, varnames, lazy_loading=True):
)
typee = self.sar_meta.geoloc[varname_in_geoloc].dtype
if lazy_loading:
da_var = map_blocks_coords(self._da_tmpl.astype(typee), interp_func)
da_var = map_blocks_coords(
self._da_tmpl.astype(typee), interp_func)
else:
da_val = interp_func(
self._dataset.digital_number.line,
Expand Down Expand Up @@ -471,7 +478,8 @@ def _resample_lut_values(self, lut):
da_var = xr.DataArray(data=var, dims=['line', 'sample'],
coords={'line': self._dataset.digital_number.line,
'sample': self._dataset.digital_number.sample})"""
da_var = map_blocks_coords(self._da_tmpl.astype(lut.dtype), interp_func)
da_var = map_blocks_coords(
self._da_tmpl.astype(lut.dtype), interp_func)
return da_var

@timing
Expand Down Expand Up @@ -510,7 +518,8 @@ def _get_lut_noise(self, var_name):
try:
lut_name = self._map_var_lut_noise[var_name]
except KeyError:
raise ValueError("can't find noise lut name for var '%s'" % var_name)
raise ValueError(
"can't find noise lut name for var '%s'" % var_name)
try:
lut = self.sar_meta.dt["radarParameters"][lut_name]
except KeyError:
Expand Down Expand Up @@ -546,7 +555,8 @@ def _interpolate_for_noise_lut(self, var_name):
noise_values = 10 ** (initial_lut / 10)
lines = np.arange(self.sar_meta.geoloc.line[-1] + 1)
noise_values_2d = np.tile(noise_values, (lines.shape[0], 1))
indexes = [first_pix + step * i for i in range(0, noise_values.shape[0])]
indexes = [first_pix + step *
i for i in range(0, noise_values.shape[0])]
interp_func = dask.delayed(RectBivariateSpline)(
x=lines, y=indexes, z=noise_values_2d, kx=1, ky=1
)
Expand Down Expand Up @@ -604,6 +614,18 @@ def apply_calibration_and_denoising(self):
% (var_name, lut_name)
)
self._dataset = self._add_denoised(self._dataset)

for var_name, lut_name in self._map_var_lut.items():
var_name_raw = var_name + "_raw"
if var_name_raw in self._dataset:
self._dataset[var_name_raw] = self._dataset[var_name_raw].where(
self._dataset[var_name_raw] > 0, 0)
else:
logger.debug(
"Skipping variable '%s' ('%s' lut is missing)"
% (var_name, lut_name)
)

self.datatree["measurement"] = self.datatree["measurement"].assign(
self._dataset
)
Expand Down Expand Up @@ -666,8 +688,6 @@ def _apply_calibration_lut(self, var_name):
# if self.resolution is not None:
lut = self._resample_lut_values(lut)
res = ((self._dataset.digital_number**2.0) + offset) / lut
# Have to know if we keep this line written by Olivier because it replaces 0 values by nan --> creates problems for wind inversion
res = res.where(res > 0)
res.attrs.update(lut.attrs)
return res.to_dataset(name=var_name + "_raw")

Expand Down Expand Up @@ -745,7 +765,8 @@ def interpolate_times(self):
interp_func = RectBivariateSpline(
x=lines, y=samples, z=time_values_2d.astype(float), kx=1, ky=1
)
da_var = map_blocks_coords(self._da_tmpl.astype("datetime64[ns]"), interp_func)
da_var = map_blocks_coords(
self._da_tmpl.astype("datetime64[ns]"), interp_func)
return da_var.isel(sample=0).to_dataset(name="time")

def get_sensor_velocity(self):
Expand Down Expand Up @@ -782,7 +803,8 @@ def get_sensor_velocity(self):
vels = np.sqrt(np.sum(velos, axis=0))
interp_f = interp1d(azimuth_times.astype(float), vels)
_vels = interp_f(interp_times.astype(float))
res = xr.DataArray(_vels, dims=["line"], coords={"line": self.dataset.line})
res = xr.DataArray(_vels, dims=["line"], coords={
"line": self.dataset.line})
return xr.Dataset({"velocity": res})

def _reconfigure_reader_datatree(self):
Expand Down Expand Up @@ -837,7 +859,8 @@ def get_list_keys_delete(dt, list_keys, inside=True):
new_dt["lut"] = dt["lut"].ds.rename(rename_lut)

# extract noise_lut, rename and put these in a dataset
new_dt["noise_lut"] = dt["radarParameters"].ds.rename(rename_radarParameters)
new_dt["noise_lut"] = dt["radarParameters"].ds.rename(
rename_radarParameters)
new_dt["noise_lut"].attrs = {} # reset attributes
delete_list = get_list_keys_delete(
new_dt["noise_lut"], rename_radarParameters.values(), inside=False
Expand Down
Loading

0 comments on commit 41b613d

Please sign in to comment.