Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify constructor of classes that inherit from the Parameters class #2103

Merged
merged 4 commits into from
Nov 8, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion taxcalc/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(self, policy=None, records=None, verbose=True,
if self.__policy.current_year < self.__records.data_year:
self.__policy.set_year(self.__records.data_year)
if consumption is None:
self.__consumption = Consumption(start_year=policy.start_year)
self.__consumption = Consumption()
elif isinstance(consumption, Consumption):
self.__consumption = copy.deepcopy(consumption)
else:
Expand Down
23 changes: 4 additions & 19 deletions taxcalc/consumption.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,7 @@ class Consumption(Parameters):

Parameters
----------
start_year: integer
first calendar year for consumption parameters.

num_years: integer
number of calendar years for which to specify parameter
values beginning with start_year.

Raises
------
ValueError:
if start_year is less than Policy.JSON_START_YEAR.
if num_years is less than one.
none

Returns
-------
Expand All @@ -41,15 +30,11 @@ class instance: Consumption
DEFAULTS_FILENAME = 'consumption.json'
DEFAULT_NUM_YEARS = Policy.DEFAULT_NUM_YEARS

def __init__(self,
start_year=JSON_START_YEAR,
num_years=DEFAULT_NUM_YEARS):
def __init__(self):
super(Consumption, self).__init__()
self._vals = self._params_dict_from_json_file()
if start_year < Policy.JSON_START_YEAR:
raise ValueError('start_year < Policy.JSON_START_YEAR')
if num_years < 1:
raise ValueError('num_years < 1')
start_year = Consumption.JSON_START_YEAR
num_years = Consumption.DEFAULT_NUM_YEARS
self.initialize(start_year, num_years)
self.parameter_errors = ''

Expand Down
23 changes: 4 additions & 19 deletions taxcalc/growdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,7 @@ class GrowDiff(Parameters):

Parameters
----------
start_year: integer
first calendar year for growth difference parameters.

num_years: integer
number of calendar years for which to specify parameter
values beginning with start_year.

Raises
------
ValueError:
if start_year is less than 2013
if num_years is less than one.
none

Returns
-------
Expand All @@ -40,15 +29,11 @@ class instance: GrowDiff
DEFAULTS_FILENAME = 'growdiff.json'
DEFAULT_NUM_YEARS = 15 # must be same as Policy.DEFAULT_NUM_YEARS

def __init__(self,
start_year=JSON_START_YEAR,
num_years=DEFAULT_NUM_YEARS):
def __init__(self):
super(GrowDiff, self).__init__()
self._vals = self._params_dict_from_json_file()
if start_year < GrowDiff.JSON_START_YEAR:
raise ValueError('start_year < GrowDiff.JSON_START_YEAR')
if num_years < 1:
raise ValueError('num_years < 1')
start_year = GrowDiff.JSON_START_YEAR
num_years = GrowDiff.DEFAULT_NUM_YEARS
self.initialize(start_year, num_years)
self.parameter_errors = ''

Expand Down
64 changes: 3 additions & 61 deletions taxcalc/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,30 +22,20 @@ class Parameters():
DEFAULTS_FILENAME = None

@classmethod
def default_data(cls, metadata=False, start_year=None):
def default_data(cls, metadata=False):
"""
Return parameter data read from the subclass's json file.

Parameters
----------
metadata: boolean

start_year: int or None

Returns
-------
params: dictionary of data
"""
# extract different data from DEFAULT_FILENAME depending on start_year
if start_year is None:
params = cls._params_dict_from_json_file()
else:
nyrs = start_year - cls.JSON_START_YEAR + 1
ppo = cls(num_years=nyrs)
ppo.set_year(start_year)
params = getattr(ppo, '_vals')
params = Parameters._revised_default_data(params, start_year,
nyrs, ppo)
# extract data from DEFAULT_FILENAME
params = cls._params_dict_from_json_file()
# return different data from params dict depending on metadata value
if metadata:
return params
Expand Down Expand Up @@ -316,54 +306,6 @@ def _validate_assump_parameter_values(self, parameters_set):
)
del parameters

@staticmethod
def _revised_default_data(params, start_year, nyrs, ppo):
"""
Return revised default parameter data.

Parameters
----------
params: dictionary of NAME:DATA pairs for each parameter
as defined in calling default_data staticmethod.

start_year: int
as defined in calling default_data staticmethod.

nyrs: int
as defined in calling default_data staticmethod.

ppo: Policy object
as defined in calling default_data staticmethod.

Returns
-------
params: dictionary of revised parameter data

Notes
-----
This staticmethod is called from default_data staticmethod in
order to reduce the complexity of the default_data staticmethod.
"""
start_year_str = '{}'.format(start_year)
for name, data in params.items():
data['start_year'] = start_year
values = data['value']
num_values = len(values)
if num_values <= nyrs:
# val should be the single start_year value
rawval = getattr(ppo, name[1:])
if isinstance(rawval, np.ndarray):
val = rawval.tolist()
else:
val = rawval
data['value'] = [val]
data['row_label'] = [start_year_str]
else: # if num_values > nyrs
# val should extend beyond the start_year value
data['value'] = data['value'][(nyrs - 1):]
data['row_label'] = data['row_label'][(nyrs - 1):]
return params

@classmethod
def _params_dict_from_json_file(cls):
"""
Expand Down
34 changes: 8 additions & 26 deletions taxcalc/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,10 @@ class Policy(Parameters):
gfactors: GrowFactors class instance
containing price inflation rates and wage growth rates

start_year: integer
first calendar year for historical policy parameters.

num_years: integer
number of calendar years for which to specify policy parameter
values beginning with start_year.

Raises
------
ValueError:
if gfactors is not a GrowFactors class instance.
if start_year is less than JSON_START_YEAR.
if num_years is less than one.
if gfactors is not a GrowFactors class instance or None.

Returns
-------
Expand All @@ -50,10 +41,7 @@ class instance: Policy
# should increase LAST_BUDGET_YEAR by one every calendar year
DEFAULT_NUM_YEARS = LAST_BUDGET_YEAR - JSON_START_YEAR + 1

def __init__(self,
gfactors=None,
start_year=JSON_START_YEAR,
num_years=DEFAULT_NUM_YEARS):
def __init__(self, gfactors=None):
super(Policy, self).__init__()

if gfactors is None:
Expand All @@ -63,21 +51,15 @@ def __init__(self,
else:
raise ValueError('gfactors is not None or a GrowFactors instance')

# read default parameters
# read default parameters and initialize
self._vals = self._params_dict_from_json_file()

if start_year < Policy.JSON_START_YEAR:
raise ValueError('start_year cannot be less than JSON_START_YEAR')
if num_years < 1:
raise ValueError('num_years cannot be less than one')

syr = start_year
lyr = start_year + num_years - 1
syr = Policy.JSON_START_YEAR
lyr = Policy.LAST_BUDGET_YEAR
nyrs = Policy.DEFAULT_NUM_YEARS
self._inflation_rates = self._gfactors.price_inflation_rates(syr, lyr)
self._apply_clp_cpi_offset(self._vals['_cpi_offset'], num_years)
self._apply_clp_cpi_offset(self._vals['_cpi_offset'], nyrs)
self._wage_growth_rates = self._gfactors.wage_growth_rates(syr, lyr)

self.initialize(start_year, num_years)
self.initialize(syr, nyrs)

self.parameter_warnings = ''
self.parameter_errors = ''
Expand Down
19 changes: 10 additions & 9 deletions taxcalc/tests/test_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,18 @@ def fixture_policyfile():


def test_make_calculator(cps_subsample):
syr = 2014
pol = Policy(start_year=syr, num_years=9)
assert pol.current_year == syr
start_year = Policy.JSON_START_YEAR
sim_year = 2018
pol = Policy()
assert pol.current_year == start_year
rec = Records.cps_constructor(data=cps_subsample)
consump = Consumption()
consump.update_consumption({syr: {'_MPC_e20400': [0.05]}})
assert consump.current_year == Consumption.JSON_START_YEAR
consump.update_consumption({sim_year: {'_MPC_e20400': [0.05]}})
assert consump.current_year == start_year
calc = Calculator(policy=pol, records=rec,
consumption=consump, behavior=Behavior())
assert calc.current_year == syr
assert calc.records_current_year() == syr
assert calc.current_year == Records.CPSCSV_YEAR
assert calc.records_current_year() == Records.CPSCSV_YEAR
# test incorrect Calculator instantiation:
with pytest.raises(ValueError):
Calculator(policy=None, records=rec)
Expand Down Expand Up @@ -224,8 +225,7 @@ def test_calculator_mtr_when_PT_rates_differ():

def test_make_calculator_increment_years_first(cps_subsample):
# create Policy object with policy reform
syr = 2013
pol = Policy(start_year=syr)
pol = Policy()
reform = {2015: {}, 2016: {}}
std5 = 2000
reform[2015]['_STD_Aged'] = [[std5, std5, std5, std5, std5]]
Expand All @@ -238,6 +238,7 @@ def test_make_calculator_increment_years_first(cps_subsample):
calc = Calculator(policy=pol, records=rec)
# compare expected policy parameter values with those embedded in calc
irates = pol.inflation_rates()
syr = Policy.JSON_START_YEAR
irate2015 = irates[2015 - syr]
irate2016 = irates[2016 - syr]
std6 = std5 * (1.0 + irate2015)
Expand Down
10 changes: 4 additions & 6 deletions taxcalc/tests/test_consumption.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@
from taxcalc import Policy, Records, Calculator, Consumption


def test_incorrect_Consumption_instantiation():
with pytest.raises(ValueError):
consump = Consumption(start_year=2000)
with pytest.raises(ValueError):
consump = Consumption(num_years=0)
def test_year_consistency():
assert Consumption.JSON_START_YEAR == Policy.JSON_START_YEAR
assert Consumption.DEFAULT_NUM_YEARS == Policy.DEFAULT_NUM_YEARS


def test_validity_of_consumption_vars_set():
Expand All @@ -22,7 +20,7 @@ def test_validity_of_consumption_vars_set():


def test_update_consumption():
consump = Consumption(start_year=2013)
consump = Consumption()
consump.update_consumption({})
consump.update_consumption({2014: {'_MPC_e20400': [0.05],
'_BEN_mcare_value': [0.75]},
Expand Down
32 changes: 13 additions & 19 deletions taxcalc/tests/test_growdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import os
import json
from numpy.testing import assert_allclose
import numpy as np
import pytest
from taxcalc import GrowDiff, GrowFactors, Policy

Expand All @@ -13,26 +13,19 @@ def test_year_consistency():
assert GrowDiff.DEFAULT_NUM_YEARS == Policy.DEFAULT_NUM_YEARS


def test_incorrect_growdiff_ctor():
with pytest.raises(ValueError):
gdiff = GrowDiff(start_year=2000)
with pytest.raises(ValueError):
gdiff = GrowDiff(num_years=0)


def test_update_and_apply_growdiff():
syr = 2013
nyrs = 5
lyr = syr + nyrs - 1
gdiff = GrowDiff(start_year=syr, num_years=nyrs)
gdiff = GrowDiff()
# update GrowDiff instance
diffs = {2014: {'_AWAGE': [0.01]},
2016: {'_AWAGE': [0.02]}}
gdiff.update_growdiff(diffs)
expected_wage_diffs = [0.00, 0.01, 0.01, 0.02, 0.02]
assert_allclose(gdiff._AWAGE, expected_wage_diffs, atol=0.0, rtol=0.0)
expected_wage_diffs = [0.00, 0.01, 0.01, 0.02, 0.02] + [0.02]*10
assert np.allclose(gdiff._AWAGE, expected_wage_diffs, atol=0.0, rtol=0.0)
# apply growdiff to GrowFactors instance
gf = GrowFactors()
syr = GrowDiff.JSON_START_YEAR
nyrs = GrowDiff.DEFAULT_NUM_YEARS
lyr = syr + nyrs - 1
pir_pre = gf.price_inflation_rates(syr, lyr)
wgr_pre = gf.wage_growth_rates(syr, lyr)
gfactors = GrowFactors()
Expand All @@ -41,8 +34,8 @@ def test_update_and_apply_growdiff():
wgr_pst = gfactors.wage_growth_rates(syr, lyr)
expected_wgr_pst = [wgr_pre[i] + expected_wage_diffs[i]
for i in range(0, nyrs)]
assert_allclose(pir_pre, pir_pst, atol=0.0, rtol=0.0)
assert_allclose(wgr_pst, expected_wgr_pst, atol=1.0e-9, rtol=0.0)
assert np.allclose(pir_pre, pir_pst, atol=0.0, rtol=0.0)
assert np.allclose(wgr_pst, expected_wgr_pst, atol=1.0e-9, rtol=0.0)


def test_incorrect_update_growdiff():
Expand All @@ -61,11 +54,12 @@ def test_incorrect_update_growdiff():


def test_has_any_response():
syr = 2014
gdiff = GrowDiff(start_year=syr)
start_year = GrowDiff.JSON_START_YEAR
gdiff = GrowDiff()
assert gdiff.current_year == start_year
assert gdiff.has_any_response() is False
gdiff.update_growdiff({2020: {'_AWAGE': [0.01]}})
assert gdiff.current_year == syr
assert gdiff.current_year == start_year
assert gdiff.has_any_response() is True


Expand Down
Loading