Skip to content

Commit

Permalink
Merge pull request #1588 from martinholmer/revise-calculator-usage
Browse files Browse the repository at this point in the history
Revise Calculator usage in unit tests
  • Loading branch information
martinholmer authored Oct 17, 2017
2 parents 6e69cbe + fd75921 commit 30c4567
Show file tree
Hide file tree
Showing 7 changed files with 172 additions and 199 deletions.
12 changes: 5 additions & 7 deletions taxcalc/calculate.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,13 +352,11 @@ def current_law_version(self):
"""
Return Calculator object same as self except with current-law policy.
"""
clp = self.policy.current_law_version()
recs = copy.deepcopy(self.records)
cons = copy.deepcopy(self.consumption)
behv = copy.deepcopy(self.behavior)
calc = Calculator(policy=clp, records=recs, sync_years=False,
consumption=cons, behavior=behv)
return calc
return Calculator(policy=self.policy.current_law_version(),
records=copy.deepcopy(self.records),
sync_years=False,
consumption=copy.deepcopy(self.consumption),
behavior=copy.deepcopy(self.behavior))

@staticmethod
def read_json_param_objects(reform, assump):
Expand Down
180 changes: 87 additions & 93 deletions taxcalc/tests/test_calculate.py

Large diffs are not rendered by default.

25 changes: 11 additions & 14 deletions taxcalc/tests/test_consumption.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,27 +67,24 @@ def test_consumption_response(cps_subsample):
with pytest.raises(ValueError):
consump.response(list(), 1)
# test correct call to response method
recs = Records.cps_constructor(data=cps_subsample)
pre = copy.deepcopy(recs.e20400)
consump.response(recs, 1.0)
post = recs.e20400
rec = Records.cps_constructor(data=cps_subsample)
pre = copy.deepcopy(rec.e20400)
consump.response(rec, 1.0)
post = rec.e20400
actual_diff = post - pre
expected_diff = np.ones(recs.dim) * mpc
expected_diff = np.ones(rec.dim) * mpc
assert np.allclose(actual_diff, expected_diff)
# compute earnings mtr with no consumption response
recs0 = Records.cps_constructor(data=cps_subsample)
calc0 = Calculator(policy=Policy(), records=recs0, consumption=None)
ided0 = copy.deepcopy(recs0.e20400)
rec = Records.cps_constructor(data=cps_subsample)
ided0 = copy.deepcopy(rec.e20400)
calc0 = Calculator(policy=Policy(), records=rec, consumption=None)
(mtr0_ptax, mtr0_itax, _) = calc0.mtr(variable_str='e00200p',
wrt_full_compensation=False)
assert np.allclose(calc0.records.e20400, ided0)
# compute earnings mtr with consumption response
recs1 = Records.cps_constructor(data=cps_subsample)
calc1 = Calculator(policy=Policy(), records=recs1, consumption=None)
assert np.allclose(calc1.records.e20400, ided0)
calc1.consumption.update_consumption(consumption_response)
(mtr1_ptax, mtr1_itax, _) = calc1.mtr(variable_str='e00200p',
wrt_full_compensation=False)
calc1 = Calculator(policy=Policy(), records=rec, consumption=consump)
mtr1_ptax, mtr1_itax, _ = calc1.mtr(variable_str='e00200p',
wrt_full_compensation=False)
assert np.allclose(calc1.records.e20400, ided0)
# confirm that payroll mtr values are no different
assert np.allclose(mtr1_ptax, mtr0_ptax)
Expand Down
12 changes: 6 additions & 6 deletions taxcalc/tests/test_macro_elasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ def test_proportional_change_in_gdp(cps_subsample):
"""
Test correct and incorrect calls to proportional_change_in_gdp function.
"""
rec1 = Records.cps_constructor(data=cps_subsample)
calc1 = Calculator(policy=Policy(), records=rec1)
rec2 = Records.cps_constructor(data=cps_subsample)
pol2 = Policy()
rec = Records.cps_constructor(data=cps_subsample)
pol = Policy()
calc1 = Calculator(policy=pol, records=rec)
reform = {2015: {'_II_em': [0.0]}} # reform increases taxes and MTRs
pol2.implement_reform(reform)
calc2 = Calculator(policy=pol2, records=rec2)
pol.implement_reform(reform)
calc2 = Calculator(policy=pol, records=rec)
assert calc1.current_year == calc2.current_year
assert calc1.current_year == 2014 # because using CPS data
gdpc = proportional_change_in_gdp(2014, calc1, calc2, elasticity=0.36)
assert gdpc == 0.0 # no effect for first data year
Expand Down
26 changes: 10 additions & 16 deletions taxcalc/tests/test_pufcsv.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,36 +235,30 @@ def test_mtr(tests_path, puf_path):


@pytest.mark.requires_pufcsv
def test_credit_reforms(puf_path):
def test_credit_reforms(puf_subsample):
"""
Test personal credit reforms using small puf.csv sub-sample
Test personal credit reforms using puf.csv subsample
"""
# pylint: disable=too-many-locals
rec = Records(data=puf_subsample)
reform_year = 2017
fullsample = pd.read_csv(puf_path)
subsample = fullsample.sample(frac=0.05, # pylint: disable=no-member
random_state=180)
# create current-law Calculator object, calc1
recs1 = Records(data=subsample)
calc1 = Calculator(policy=Policy(), records=recs1)
pol = Policy()
calc1 = Calculator(policy=pol, records=rec)
calc1.advance_to_year(reform_year)
calc1.calc_all()
itax1 = (calc1.records.iitax * calc1.records.s006).sum()
# create personal-refundable-credit-reform Calculator object, calc2
recs2 = Records(data=subsample)
policy2 = Policy()
reform = {reform_year: {'_II_credit': [[1000, 1000, 1000, 1000, 1000]]}}
policy2.implement_reform(reform)
calc2 = Calculator(policy=policy2, records=recs2)
pol.implement_reform(reform)
calc2 = Calculator(policy=pol, records=rec)
calc2.advance_to_year(reform_year)
calc2.calc_all()
itax2 = (calc2.records.iitax * calc2.records.s006).sum()
# create personal-nonrefundable-credit-reform Calculator object, calc3
recs3 = Records(data=subsample)
policy3 = Policy()
reform = {reform_year: {'_II_credit_nr': [[1000, 1000, 1000, 1000, 1000]]}}
policy3.implement_reform(reform)
calc3 = Calculator(policy=policy3, records=recs3)
pol = Policy()
pol.implement_reform(reform)
calc3 = Calculator(policy=pol, records=rec)
calc3.advance_to_year(reform_year)
calc3.calc_all()
itax3 = (calc3.records.iitax * calc3.records.s006).sum()
Expand Down
16 changes: 7 additions & 9 deletions taxcalc/tests/test_reforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,20 @@ def reform_results(reform_dict, puf_data):
"""
# pylint: disable=too-many-locals
# create current-law-policy Calculator object
pol1 = Policy()
rec1 = Records(data=puf_data)
calc1 = Calculator(policy=pol1, records=rec1, verbose=False, behavior=None)
pol = Policy()
rec = Records(data=puf_data)
calc1 = Calculator(policy=pol, records=rec, verbose=False, behavior=None)
# create reform Calculator object with possible behavioral responses
start_year = reform_dict['start_year']
beh2 = Behavior()
beh = Behavior()
if '_BE_cg' in reform_dict['value']:
elasticity = reform_dict['value']['_BE_cg']
del reform_dict['value']['_BE_cg'] # in order to have a valid reform
beh_assump = {start_year: {'_BE_cg': elasticity}}
beh2.update_behavior(beh_assump)
beh.update_behavior(beh_assump)
reform = {start_year: reform_dict['value']}
pol2 = Policy()
pol2.implement_reform(reform)
rec2 = Records(data=puf_data)
calc2 = Calculator(policy=pol2, records=rec2, verbose=False, behavior=beh2)
pol.implement_reform(reform)
calc2 = Calculator(policy=pol, records=rec, verbose=False, behavior=beh)
# increment both calculators to reform's start_year
calc1.advance_to_year(start_year)
calc2.advance_to_year(start_year)
Expand Down
100 changes: 46 additions & 54 deletions taxcalc/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,17 @@ def test_validity_of_name_lists():

def test_create_tables(cps_subsample):
# create a current-law Policy object and Calculator object calc1
policy1 = Policy()
records1 = Records.cps_constructor(data=cps_subsample)
calc1 = Calculator(policy=policy1, records=records1)
rec = Records.cps_constructor(data=cps_subsample)
pol = Policy()
calc1 = Calculator(policy=pol, records=rec)
calc1.calc_all()
# create a policy-reform Policy object and Calculator object calc2
reform = {2013: {'_II_rt1': [0.15]}}
policy2 = Policy()
policy2.implement_reform(reform)
records2 = Records.cps_constructor(data=cps_subsample)
calc2 = Calculator(policy=policy2, records=records2)
pol.implement_reform(reform)
calc2 = Calculator(policy=pol, records=rec)
calc2.calc_all()

# test creating various difference tables

diff = create_difference_table(calc1.records, calc2.records,
groupby='large_income_bins',
income_measure='expanded_income',
Expand Down Expand Up @@ -574,8 +572,8 @@ def test_add_quantile_bins():


def test_dist_table_sum_row(cps_subsample):
recs = Records.cps_constructor(data=cps_subsample)
calc = Calculator(policy=Policy(), records=recs)
rec = Records.cps_constructor(data=cps_subsample)
calc = Calculator(policy=Policy(), records=rec)
calc.calc_all()
tb1 = create_distribution_table(calc.records,
groupby='small_income_bins',
Expand All @@ -594,17 +592,15 @@ def test_dist_table_sum_row(cps_subsample):


def test_diff_table_sum_row(cps_subsample):
rec = Records.cps_constructor(data=cps_subsample)
# create a current-law Policy object and Calculator calc1
policy1 = Policy()
records1 = Records.cps_constructor(data=cps_subsample)
calc1 = Calculator(policy=policy1, records=records1)
pol = Policy()
calc1 = Calculator(policy=pol, records=rec)
calc1.calc_all()
# create a policy-reform Policy object and Calculator calc2
reform = {2013: {'_II_rt4': [0.56]}}
policy2 = Policy()
policy2.implement_reform(reform)
records2 = Records.cps_constructor(data=cps_subsample)
calc2 = Calculator(policy=policy2, records=records2)
pol.implement_reform(reform)
calc2 = Calculator(policy=pol, records=rec)
calc2.calc_all()
# create two difference tables and compare their content
tdiff1 = create_difference_table(calc1.records, calc2.records,
Expand All @@ -628,21 +624,21 @@ def test_mtr_graph_data(cps_subsample):
calc = Calculator(policy=Policy(),
records=Records.cps_constructor(data=cps_subsample))
with pytest.raises(ValueError):
gdata = mtr_graph_data(calc, calc, mars='bad',
income_measure='agi',
dollar_weighting=True)
mtr_graph_data(calc, calc, mars='bad',
income_measure='agi',
dollar_weighting=True)
with pytest.raises(ValueError):
gdata = mtr_graph_data(calc, calc, mars=0,
income_measure='expanded_income',
dollar_weighting=True)
mtr_graph_data(calc, calc, mars=0,
income_measure='expanded_income',
dollar_weighting=True)
with pytest.raises(ValueError):
gdata = mtr_graph_data(calc, calc, mars=list())
mtr_graph_data(calc, calc, mars=list())
with pytest.raises(ValueError):
gdata = mtr_graph_data(calc, calc, mars='ALL', mtr_variable='e00200s')
mtr_graph_data(calc, calc, mars='ALL', mtr_variable='e00200s')
with pytest.raises(ValueError):
gdata = mtr_graph_data(calc, calc, mtr_measure='badtax')
mtr_graph_data(calc, calc, mtr_measure='badtax')
with pytest.raises(ValueError):
gdata = mtr_graph_data(calc, calc, income_measure='badincome')
mtr_graph_data(calc, calc, income_measure='badincome')
gdata = mtr_graph_data(calc, calc, mars=1,
mtr_wrt_full_compen=True,
income_measure='wages',
Expand All @@ -651,25 +647,25 @@ def test_mtr_graph_data(cps_subsample):


def test_atr_graph_data(cps_subsample):
calc = Calculator(policy=Policy(),
records=Records.cps_constructor(data=cps_subsample))
pol = Policy()
rec = Records.cps_constructor(data=cps_subsample)
calc = Calculator(policy=pol, records=rec)
with pytest.raises(ValueError):
gdata = atr_graph_data(calc, calc, mars='bad')
atr_graph_data(calc, calc, mars='bad')
with pytest.raises(ValueError):
gdata = atr_graph_data(calc, calc, mars=0)
atr_graph_data(calc, calc, mars=0)
with pytest.raises(ValueError):
gdata = atr_graph_data(calc, calc, mars=list())
atr_graph_data(calc, calc, mars=list())
with pytest.raises(ValueError):
gdata = atr_graph_data(calc, calc, atr_measure='badtax')
atr_graph_data(calc, calc, atr_measure='badtax')
gdata = atr_graph_data(calc, calc, mars=1, atr_measure='combined')
gdata = atr_graph_data(calc, calc, atr_measure='itax')
gdata = atr_graph_data(calc, calc, atr_measure='ptax')
assert isinstance(gdata, dict)
with pytest.raises(ValueError):
calcx = Calculator(policy=Policy(),
records=Records.cps_constructor(data=cps_subsample))
calcx = Calculator(policy=pol, records=rec)
calcx.advance_to_year(2020)
gdata = atr_graph_data(calcx, calc)
atr_graph_data(calcx, calc)


def test_xtr_graph_plot(cps_subsample):
Expand Down Expand Up @@ -721,32 +717,30 @@ def test_write_graph_file(cps_subsample):


def test_multiyear_diagnostic_table(cps_subsample):
behv = Behavior()
calc = Calculator(policy=Policy(),
records=Records.cps_constructor(data=cps_subsample),
behavior=behv)
rec = Records.cps_constructor(data=cps_subsample)
pol = Policy()
beh = Behavior()
calc = Calculator(policy=pol, records=rec, behavior=beh)
with pytest.raises(ValueError):
adt = multiyear_diagnostic_table(calc, 0)
multiyear_diagnostic_table(calc, 0)
with pytest.raises(ValueError):
adt = multiyear_diagnostic_table(calc, 20)
multiyear_diagnostic_table(calc, 20)
adt = multiyear_diagnostic_table(calc, 3)
assert isinstance(adt, pd.DataFrame)
behv.update_behavior({2013: {'_BE_sub': [0.3]}})
calc = Calculator(policy=Policy(),
records=Records.cps_constructor(data=cps_subsample),
behavior=behv)
beh.update_behavior({2013: {'_BE_sub': [0.3]}})
calc = Calculator(policy=pol, records=rec, behavior=beh)
assert calc.behavior.has_response()
adt = multiyear_diagnostic_table(calc, 3)
assert isinstance(adt, pd.DataFrame)


def test_myr_diag_table_wo_behv(cps_subsample):
pol = Policy()
reform = {
2013: {
'_II_rt7': [0.33],
'_PT_rt7': [0.33],
}}
pol = Policy()
pol.implement_reform(reform)
calc = Calculator(policy=pol,
records=Records.cps_constructor(data=cps_subsample))
Expand Down Expand Up @@ -794,19 +788,17 @@ def test_ce_aftertax_income(cps_subsample):
cmin = 1000
assert con == round(certainty_equivalent(con, 0, cmin), 6)
# test with require_no_agg_tax_change equal to False
rec = Records.cps_constructor(data=cps_subsample)
cyr = 2020
# specify calc1 and calc_all() for cyr
pol1 = Policy()
rec1 = Records.cps_constructor(data=cps_subsample)
calc1 = Calculator(policy=pol1, records=rec1)
pol = Policy()
calc1 = Calculator(policy=pol, records=rec)
calc1.advance_to_year(cyr)
calc1.calc_all()
# specify calc2 and calc_all() for cyr
pol2 = Policy()
reform = {2018: {'_II_em': [0.0]}}
pol2.implement_reform(reform)
rec2 = Records.cps_constructor(data=cps_subsample)
calc2 = Calculator(policy=pol2, records=rec2)
pol.implement_reform(reform)
calc2 = Calculator(policy=pol, records=rec)
calc2.advance_to_year(cyr)
calc2.calc_all()
cedict = ce_aftertax_income(calc1, calc2, require_no_agg_tax_change=False)
Expand Down

0 comments on commit 30c4567

Please sign in to comment.