Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update some example notebooks #12

Merged
merged 5 commits into from
Sep 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions jaxtronomy/Analysis/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import jax.numpy as jnp
import matplotlib.pyplot as plt
from scipy import ndimage
from matplotlib.colors import LogNorm
from matplotlib.colors import Normalize, LogNorm

from jaxtronomy.Util.plot_util import nice_colorbar, nice_colorbar_residuals
from jaxtronomy.Util import image_util

# Some general default for plotting
plt.rc('image', interpolation='none', origin='lower') # imshow
plt.rc('image', interpolation='none', origin='lower') # for imshow


class Plotter(object):
Expand All @@ -32,19 +32,20 @@ class Plotter(object):
cmap_flux = copy.copy(cmap_base)
cmap_flux.set_bad(color='black')
cmap_flux_alt = copy.copy(cmap_base)
cmap_flux_alt.set_bad(color='#222222') # emphasize e.g. non-positive pixels in log scale
cmap_flux_alt.set_bad(color='#222222') # to emphasize non-positive pixels in log scale
cmap_resid = plt.get_cmap('RdBu_r')
cmap_default = plt.get_cmap('viridis')
cmap_deriv1 = plt.get_cmap('cividis')
cmap_deriv2 = plt.get_cmap('inferno')

def __init__(self, base_fontsize=0.28, flux_log_scale=True,
flux_vmin=None, flux_vmax=None):
flux_vmin=None, flux_vmax=None, res_vmax=6):
self._base_fs = base_fontsize
if flux_log_scale is True:
self.norm_flux = LogNorm(flux_vmin, flux_vmax)
else:
self.norm_flux = None
self.norm_res = Normalize(-res_vmax, res_vmax)

def set_data(self, data):
self._data = data
Expand Down Expand Up @@ -167,12 +168,12 @@ def model_summary(self, lens_image, kwargs_result,
nice_colorbar(im, position='top', pad=0.4, size=0.2,
colorbar_kwargs={'orientation': 'horizontal'})
ax = axes[i_row, 2]
norm_res = lens_image.normalized_residuals(data, model, mask=likelihood_mask)
residuals = lens_image.normalized_residuals(data, model, mask=likelihood_mask)
red_chi2 = lens_image.reduced_chi2(data, model, mask=likelihood_mask)
im = ax.imshow(norm_res * likelihood_mask, cmap=self.cmap_resid, vmin=-4, vmax=4, extent=extent)
im = ax.imshow(residuals * likelihood_mask, cmap=self.cmap_resid, extent=extent, norm=self.norm_res)
ax.set_title(r"(f${}_{\rm model}$ - f${}_{\rm data})/\sigma$", fontsize=self._base_fs)
nice_colorbar_residuals(im, norm_res, position='top', pad=0.4, size=0.2,
vmin=-4, vmax=4,
nice_colorbar_residuals(im, residuals, position='top', pad=0.4, size=0.2,
vmin=self.norm_res.vmin, vmax=self.norm_res.vmax,
colorbar_kwargs={'orientation': 'horizontal'})
text = r"$\chi^2={:.2f}$".format(red_chi2)
ax.text(0.05, 0.05, text, color='black', # fontsize=,
Expand Down Expand Up @@ -202,7 +203,7 @@ def model_summary(self, lens_image, kwargs_result,
diff = source_model - true_source
vmax_diff = true_source.max() / 10.
im = ax.imshow(diff, extent=src_extent,
cmap=self.cmap_resid, vmin=-vmax_diff, vmax=vmax_diff)
cmap=self.cmap_resid, norm=Normalize(-vmax_diff, vmax_diff))
ax.set_title(r"s${}_{\rm model}$ - s${}_{\rm truth}$", fontsize=self._base_fs)
nice_colorbar_residuals(im, diff, position='top', pad=0.4, size=0.2,
vmin=-vmax_diff, vmax=vmax_diff,
Expand Down
4 changes: 2 additions & 2 deletions jaxtronomy/Coordinates/pixel_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def model_pixel_extent(self, name):

def create_model_grid(self, grid_center=None, grid_shape=None,
pixel_scale_factor=None, conserve_extent=False,
name='none'):
name='none', overwrite=False):
"""
:param grid_center: 2-tuple (center_x, center_y) with grid center in physical units
If None, defaults to the original grid center.
Expand All @@ -134,7 +134,7 @@ def create_model_grid(self, grid_center=None, grid_shape=None,
Otherwise, the extent will be computed based on the final pixel width.
:param name: unique string for identifying the created grid.
"""
if name in self._model_grids:
if not overwrite and name in self._model_grids:
raise ValueError(f"Grid name '{name}' is already used for another grid")

unchanged_count = 0
Expand Down
428 changes: 138 additions & 290 deletions notebooks/fit_pixelated_GRF_potential.ipynb

Large diffs are not rendered by default.

394 changes: 159 additions & 235 deletions notebooks/fit_pixelated_compact_potential.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"# Find a localised dark matter halo along with the source\n",
"### Fit pixelated lens potential perturbations and a pixelated source, assuming a known SIE model\n",
"\n",
"__last update__: 28/07/21"
"__last update__: 28/07/21 (not compatible with latest updates!)"
]
},
{
Expand Down