Skip to content

Commit

Permalink
Merge pull request #978 from jdebacker/colormap
Browse files Browse the repository at this point in the history
Merging
  • Loading branch information
rickecon authored Sep 18, 2024
2 parents 92f7594 + 8966208 commit 022d3b4
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 12 deletions.
12 changes: 8 additions & 4 deletions ogcore/output_plots.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from mpl_toolkits.mplot3d import Axes3D
import matplotlib
from ogcore.constants import (
Expand Down Expand Up @@ -391,7 +392,7 @@ def ss_3Dplot(
data = (reform_ss[var] - base_ss[var]).T
elif plot_type == "pct_diff":
data = ((reform_ss[var] - base_ss[var]) / base_ss[var]).T
cmap1 = matplotlib.cm.get_cmap("jet")
cmap1 = matplotlib.colormaps.get_cmap("jet")
X, Y = np.meshgrid(domain, Jgrid)
fig5, ax5 = plt.subplots(subplot_kw={"projection": "3d"})
ax5.set_xlabel(r"age-$s$")
Expand Down Expand Up @@ -652,7 +653,7 @@ def ability_bar_ss(
plt.ylabel(r"Percentage Change in " + VAR_LABELS[var])
if plot_title:
plt.title(plot_title, fontsize=15)
plt.legend(loc=9, bbox_to_anchor=(0.5, -0.15), ncol=2)
# plt.legend(loc=9, bbox_to_anchor=(0.5, -0.15), ncol=2)
if path:
fig_path1 = os.path.join(path)
plt.savefig(fig_path1, bbox_inches="tight", dpi=300)
Expand Down Expand Up @@ -1199,14 +1200,17 @@ def inequality_plot(
plt.title(plot_title, fontsize=15)
vals = ax1.get_yticks()
if plot_type == "pct_diff":
ax1.set_yticklabels(["{:,.2%}".format(x) for x in vals])
ticks_loc = ax1.get_yticks().tolist()
ax1.yaxis.set_major_locator(mticker.FixedLocator(ticks_loc))
ax1.set_yticklabels(["{:,.2%}".format(x) for x in ticks_loc])
plt.xlim(
(
base_params.start_year - 1,
base_params.start_year + num_years_to_plot,
)
)
plt.legend(loc=9, bbox_to_anchor=(0.5, -0.15), ncol=2)
if plot_type == "levels":
plt.legend(loc=9, bbox_to_anchor=(0.5, -0.15), ncol=2)
if path:
fig_path1 = os.path.join(path)
plt.savefig(fig_path1, bbox_inches="tight", dpi=300)
Expand Down
21 changes: 13 additions & 8 deletions ogcore/parameter_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.ticker as mticker
from ogcore.constants import GROUP_LABELS
from ogcore import utils, txfunc
from ogcore.constants import DEFAULT_START_YEAR, VAR_LABELS
Expand Down Expand Up @@ -107,8 +108,9 @@ def plot_mort_rates(
plt.ylabel(r"Mortality Rates $\rho_{s}$")
plt.legend(loc="upper left")
title = "Mortality Rates"
vals = ax.get_yticks()
ax.set_yticklabels(["{:,.0%}".format(x) for x in vals])
ticks_loc = ax.get_yticks().tolist()
ax.yaxis.set_major_locator(mticker.FixedLocator(ticks_loc))
ax.set_yticklabels(["{:,.0%}".format(x) for x in ticks_loc])
if include_title:
plt.title(title)
if path is None:
Expand Down Expand Up @@ -150,8 +152,9 @@ def plot_pop_growth(
plt.plot(year_vec, p.g_n[start_index : start_index + num_years_to_plot])
plt.xlabel(r"Year $t$")
plt.ylabel(r"Population Growth Rate $g_{n, t}$")
vals = ax.get_yticks()
ax.set_yticklabels(["{:,.2%}".format(x) for x in vals])
ticks_loc = ax.get_yticks().tolist()
ax.yaxis.set_major_locator(mticker.FixedLocator(ticks_loc))
ax.set_yticklabels(["{:,.2%}".format(x) for x in ticks_loc])
if include_title:
plt.title("Population Growth Rates")
if path is None:
Expand Down Expand Up @@ -485,9 +488,11 @@ def plot_g_n(p_list, label_list=[""], include_title=False, path=None):
plt.plot(years, p.g_n[: p.T], label=label_list[i])
plt.xlabel(r"Year $s$ (model periods)")
plt.ylabel(r"Population Growth Rate $g_{n,t}$")
plt.legend(loc="upper right")
vals = ax.get_yticks()
ax.set_yticklabels(["{:,.0%}".format(x) for x in vals])
if label_list[0] != "":
plt.legend(loc="upper right")
ticks_loc = ax.get_yticks().tolist()
ax.yaxis.set_major_locator(mticker.FixedLocator(ticks_loc))
ax.set_yticklabels(["{:,.0%}".format(x) for x in ticks_loc])
if include_title:
plt.title("Population Growth Rates")
if path is None:
Expand Down Expand Up @@ -972,7 +977,7 @@ def plot_income_data(
t = -1
J = abil_midp.shape[0]
abil_mesh, age_mesh = np.meshgrid(abil_midp, ages)
cmap1 = matplotlib.cm.get_cmap("summer")
cmap1 = matplotlib.colormaps["summer"]
if path:
# Make sure that directory is created
utils.mkdirs(path)
Expand Down
10 changes: 10 additions & 0 deletions tests/test_output_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sys
import numpy as np
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from ogcore import utils, output_plots, constants


Expand Down Expand Up @@ -166,6 +167,7 @@ def test_plot_aggregates(
plot_title=plot_title,
)
assert fig
plt.close()


test_data = [
Expand Down Expand Up @@ -217,6 +219,7 @@ def test_plot_industry_aggregates(
plot_title=plot_title,
)
assert fig
plt.close()


test_data = [
Expand Down Expand Up @@ -300,6 +303,7 @@ def test_plot_gdp_ratio(
plot_title=plot_title,
)
assert fig
plt.close()


def test_plot_gdp_ratio_save_fig(tmpdir):
Expand Down Expand Up @@ -327,6 +331,7 @@ def test_ability_bar():
plot_title=" Test Plot Title",
)
assert fig
plt.close()


def test_ability_bar_save_fig(tmpdir):
Expand All @@ -353,6 +358,7 @@ def test_ability_bar_ss():
plot_title=" Test Plot Title",
)
assert fig
plt.close()


data_for_plot = np.ones(80) * 0.3
Expand All @@ -374,6 +380,7 @@ def test_ss_profiles(by_j, plot_data):
plot_title=" Test Plot Title",
)
assert fig
plt.close()


def test_ss_profiles_save_fig(tmpdir):
Expand All @@ -398,6 +405,7 @@ def test_tpi_profiles(by_j):
plot_title=" Test Plot Title",
)
assert fig
plt.close()


test_data = [
Expand Down Expand Up @@ -454,6 +462,7 @@ def test_ss_3Dplot(
plot_title=plot_title,
)
assert fig
plt.close()


def test_ss_3Dplot_save_fig(tmpdir):
Expand Down Expand Up @@ -540,6 +549,7 @@ def test_inequality_plot(
plot_type=plot_type,
)
assert fig
plt.close()


def test_inequality_plot_save_fig(tmpdir):
Expand Down
17 changes: 17 additions & 0 deletions tests/test_parameter_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import scipy.interpolate as si
import matplotlib.image as mpimg
from ogcore import utils, parameter_plots, Specifications
Expand Down Expand Up @@ -74,13 +75,15 @@ def test_plot_imm_rates_save_fig(tmpdir):
def test_plot_mort_rates():
fig = parameter_plots.plot_mort_rates([base_params], include_title=True)
assert fig
plt.close()


def test_plot_surv_rates():
fig = parameter_plots.plot_mort_rates(
[base_params], survival_rates=True, include_title=True
)
assert fig
plt.close()


def test_plot_mort_rates_save_fig(tmpdir):
Expand All @@ -104,6 +107,7 @@ def test_plot_pop_growth():
base_params, start_year=int(base_params.start_year), include_title=True
)
assert fig
plt.close()


def test_plot_pop_growth_rates_save_fig(tmpdir):
Expand All @@ -119,6 +123,7 @@ def test_plot_ability_profiles():
p = Specifications()
fig = parameter_plots.plot_ability_profiles(p, p2=p, include_title=True)
assert fig
plt.close()


def test_plot_log_ability_profiles():
Expand All @@ -127,6 +132,7 @@ def test_plot_log_ability_profiles():
p, p2=p, log_scale=True, include_title=True
)
assert fig
plt.close()


def test_plot_ability_profiles_save_fig(tmpdir):
Expand All @@ -144,6 +150,7 @@ def test_plot_elliptical_u():
)
assert fig1
assert fig2
plt.close()


def test_plot_elliptical_u_save_fig(tmpdir):
Expand All @@ -157,6 +164,7 @@ def test_plot_chi_n():
p = Specifications()
fig = parameter_plots.plot_chi_n([p], include_title=True)
assert fig
plt.close()


def test_plot_chi_n_save_fig(tmpdir):
Expand All @@ -177,6 +185,7 @@ def test_plot_population(years_to_plot):
base_params, years_to_plot=years_to_plot, include_title=True
)
assert fig
plt.close()


def test_plot_population_save_fig(tmpdir):
Expand Down Expand Up @@ -215,6 +224,7 @@ def test_plot_fert_rates():
fert_rates = np.random.uniform(size=totpers).reshape((1, totpers))
fig = parameter_plots.plot_fert_rates([fert_rates], include_title=True)
assert fig
plt.close()


def test_plot_fert_rates_save_fig(tmpdir):
Expand Down Expand Up @@ -258,6 +268,7 @@ def test_plot_g_n():
p = Specifications()
fig = parameter_plots.plot_g_n([p], include_title=True)
assert fig
plt.close()


def test_plot_g_n_savefig(tmpdir):
Expand All @@ -276,6 +287,7 @@ def test_plot_mort_rates_data():
path=None,
)
assert fig
plt.close()


def test_plot_mort_rates_data_save_fig(tmpdir):
Expand All @@ -300,6 +312,7 @@ def test_plot_omega_fixed():
age_per_EpS, omega_SS_orig, omega_SSfx, E, S
)
assert fig
plt.close()


def test_plot_omega_fixed_save_fig(tmpdir):
Expand All @@ -326,6 +339,7 @@ def test_plot_imm_fixed():
age_per_EpS, imm_rates_orig, imm_rates_adj, E, S
)
assert fig
plt.close()


def test_plot_imm_fixed_save_fig(tmpdir):
Expand Down Expand Up @@ -360,6 +374,7 @@ def test_plot_population_path():
S,
)
assert fig
plt.close()


def test_plot_population_path_save_fig(tmpdir):
Expand Down Expand Up @@ -398,6 +413,7 @@ def test_plot_income_data():
fig = parameter_plots.plot_income_data(ages, abil_midp, abil_pcts, emat)

assert fig
plt.close()


def test_plot_income_data_save_fig(tmpdir):
Expand Down Expand Up @@ -481,6 +497,7 @@ def test_plot_2D_taxfunc(
)

assert fig
plt.close()
else:
assert True

Expand Down

0 comments on commit 022d3b4

Please sign in to comment.