Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor fuzzing logic #1979

Merged
merged 4 commits into from
Apr 25, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions taxcalc/calculate.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,9 @@ def have_same_income_measure(calc1, calc2, income_measure):
income_measure == 'c00100')
assert (result_type == 'weighted_sum' or
result_type == 'weighted_avg')
if calc is not None:
assert np.allclose(self.array('s006'),
calc.array('s006')) # check rows in same order
var_dataframe = self.distribution_table_dataframe()
dt1 = create_distribution_table(var_dataframe,
groupby=groupby,
Expand Down
79 changes: 54 additions & 25 deletions taxcalc/tbi/tbi.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
from taxcalc.tbi.tbi_utils import (check_years_return_first_year,
calculate,
random_seed,
summary,
fuzzed,
summary_aggregate,
summary_dist_xbin, summary_diff_xbin,
summary_dist_xdec, summary_diff_xdec,
create_dict_table,
AGGR_ROW_NAMES)
from taxcalc import (DIST_TABLE_LABELS, DIFF_TABLE_LABELS,
Expand Down Expand Up @@ -96,35 +99,61 @@ def run_nth_year_tax_calc_model(year_n, start_year,
Setting use_full_sample=False implies use sub-sample of input file;
otherwise, use the complete sample.
"""
# pylint: disable=too-many-arguments,too-many-locals
# pylint: disable=too-many-arguments,too-many-locals,too-many-branches

start_time = time.time()

# create calc1 and calc2 calculated for year_n
check_years_return_first_year(year_n, start_year, use_puf_not_cps)
(calc1, calc2) = calculate(year_n, start_year,
use_puf_not_cps, use_full_sample,
user_mods,
behavior_allowed=True)
calc1, calc2 = calculate(year_n, start_year,
use_puf_not_cps, use_full_sample,
user_mods,
behavior_allowed=True)

# extract raw results from calc1 and calc2
rawres1 = calc1.distribution_table_dataframe()
rawres2 = calc2.distribution_table_dataframe()
# extract unfuzzed raw results from calc1 and calc2
dv1 = calc1.distribution_table_dataframe()
dv2 = calc2.distribution_table_dataframe()

# delete calc1 and calc2 now that raw results have been extracted
del calc1
del calc2

# seed random number generator with a seed value based on user_mods
seed = random_seed(user_mods)
print('seed={}'.format(seed))
np.random.seed(seed) # pylint: disable=no-member

# construct TaxBrain summary results from raw results
summ = summary(rawres1, rawres2, use_puf_not_cps)
del rawres1
del rawres2
sres = dict()
fuzzing = use_puf_not_cps
if fuzzing:
# seed random number generator with a seed value based on user_mods
# (reform-specific seed is used to choose whose results are fuzzed)
seed = random_seed(user_mods)
print('fuzzing_seed={}'.format(seed))
np.random.seed(seed) # pylint: disable=no-member
# make bool array marking which filing units are affected by the reform
reform_affected = np.logical_not( # pylint: disable=no-member
np.isclose(dv1['combined'], dv2['combined'], atol=0.01, rtol=0.0)
)
agg1, agg2 = fuzzed(dv1, dv2, reform_affected, 'aggr')
sres = summary_aggregate(sres, agg1, agg2)
del agg1
del agg2
dv1b, dv2b = fuzzed(dv1, dv2, reform_affected, 'xbin')
sres = summary_dist_xbin(sres, dv1b, dv2b)
sres = summary_diff_xbin(sres, dv1b, dv2b)
del dv1b
del dv2b
dv1d, dv2d = fuzzed(dv1, dv2, reform_affected, 'xdec')
sres = summary_dist_xdec(sres, dv1d, dv2d)
sres = summary_diff_xdec(sres, dv1d, dv2d)
del dv1d
del dv2d
del reform_affected
else:
sres = summary_aggregate(sres, dv1, dv2)
sres = summary_dist_xbin(sres, dv1, dv2)
sres = summary_diff_xbin(sres, dv1, dv2)
sres = summary_dist_xdec(sres, dv1, dv2)
sres = summary_diff_xdec(sres, dv1, dv2)

# nested function used below
def append_year(pdf):
"""
append_year embedded function revises all column names in pdf
Expand All @@ -135,22 +164,22 @@ def append_year(pdf):
# optionally return non-JSON-like results
if not return_dict:
res = dict()
for tbl in summ:
res[tbl] = append_year(summ[tbl])
for tbl in sres:
res[tbl] = append_year(sres[tbl])
elapsed_time = time.time() - start_time
print('elapsed time for this run: {:.1f}'.format(elapsed_time))
return res

# optionally construct JSON-like results dictionaries for year n
dec_rownames = list(summ['diff_comb_xdec'].index.values)
dec_rownames = list(sres['diff_comb_xdec'].index.values)
dec_row_names_n = [x + '_' + str(year_n) for x in dec_rownames]
bin_rownames = list(summ['diff_comb_xbin'].index.values)
bin_rownames = list(sres['diff_comb_xbin'].index.values)
bin_row_names_n = [x + '_' + str(year_n) for x in bin_rownames]
agg_row_names_n = [x + '_' + str(year_n) for x in AGGR_ROW_NAMES]
dist_column_types = [float] * len(DIST_TABLE_LABELS)
diff_column_types = [float] * len(DIFF_TABLE_LABELS)
info = dict()
for tbl in summ:
for tbl in sres:
info[tbl] = {'row_names': [], 'col_types': []}
if 'dec' in tbl:
info[tbl]['row_names'] = dec_row_names_n
Expand All @@ -163,13 +192,13 @@ def append_year(pdf):
elif 'diff' in tbl:
info[tbl]['col_types'] = diff_column_types
res = dict()
for tbl in summ:
for tbl in sres:
if 'aggr' in tbl:
res_table = create_dict_table(summ[tbl],
res_table = create_dict_table(sres[tbl],
row_names=info[tbl]['row_names'])
res[tbl] = dict((k, v[0]) for k, v in res_table.items())
else:
res[tbl] = create_dict_table(summ[tbl],
res[tbl] = create_dict_table(sres[tbl],
row_names=info[tbl]['row_names'],
column_types=info[tbl]['col_types'])

Expand Down
Loading