Skip to content

Commit

Permalink
Merge pull request #1441 from martinholmer/fix-sampling
Browse files Browse the repository at this point in the history
Improve sample weight adjustment
  • Loading branch information
martinholmer authored Jun 27, 2017
2 parents 0f46a04 + 52242f5 commit cf86e07
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 11 deletions.
6 changes: 5 additions & 1 deletion RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,15 @@ Release 0.9.1 on 2017-??-??
[#XXXX](https://github.com/open-source-economics/Tax-Calculator/pull/XXXX))

**API Changes**
- None

**New Features**
- Improve calculation of sub-sample weights
[[#1441](https://github.com/open-source-economics/Tax-Calculator/pull/1441)
by Hank Doupe]

**Bug Fixes**

- None

Release 0.9.0 on 2017-06-14
---------------------------
Expand Down
8 changes: 6 additions & 2 deletions taxcalc/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# pep8 --ignore=E402 records.py
# pylint --disable=locally-disabled records.py

from __future__ import division
import os
import json
import six
Expand Down Expand Up @@ -146,9 +147,12 @@ def __init__(self,
self._read_adjust(adjust_ratios)
# weights must be same size as tax record data
if not self.WT.empty and self.dim != len(self.WT):
frac = float(self.dim) / len(self.WT)
# scale-up sub-sample weights by year-specific factor
sum_full_weights = self.WT.sum()
self.WT = self.WT.iloc[self.index]
self.WT = self.WT / frac
sum_sub_weights = self.WT.sum()
factor = sum_full_weights / sum_sub_weights
self.WT = self.WT * factor
# specify current_year and FLPDYR values
if isinstance(start_year, int):
self._current_year = start_year
Expand Down
17 changes: 9 additions & 8 deletions taxcalc/tests/test_pufcsv.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def puf_path(tests_path):
def test_agg(tests_path, puf_path): # pylint: disable=redefined-outer-name
"""
Test Tax-Calculator aggregate taxes with no policy reform using
the full-sample puf.csv and a two-percent sub-sample of puf.csv
the full-sample puf.csv and a small sub-sample of puf.csv
"""
# pylint: disable=too-many-locals,too-many-statements
nyrs = 10
Expand Down Expand Up @@ -78,23 +78,24 @@ def test_agg(tests_path, puf_path): # pylint: disable=redefined-outer-name
msg += '--- and rerun test. ---\n'
msg += '-------------------------------------------------\n'
raise ValueError(msg)
# create aggregate diagnostic table using sub sample of records
# create aggregate diagnostic table using unweighted sub-sample of records
fullsample = pd.read_csv(puf_path)
rn_seed = 80 # to ensure sub-sample is always the same
subfrac = 0.02 # sub-sample fraction
rn_seed = 180 # to ensure sub-sample is always the same
subfrac = 0.05 # sub-sample fraction
subsample = fullsample.sample(frac=subfrac, # pylint: disable=no-member
random_state=rn_seed)
rec_subsample = Records(data=subsample)
calc_subsample = Calculator(policy=Policy(), records=rec_subsample)
adt_subsample = multiyear_diagnostic_table(calc_subsample, num_years=nyrs)
# compare combined tax liability from full and sub samples for each year
taxes_subsample = adt_subsample.loc["Combined Liability ($b)"]
reltol = 0.04 # maximum allowed relative difference in tax liability
reltol = 0.01 # maximum allowed relative difference in tax liability
if not np.allclose(taxes_subsample, taxes_fullsample,
atol=0.0, rtol=reltol):
msg = 'PUFCSV AGG RESULTS DIFFER IN SUB-SAMPLE AND FULL-SAMPLE\n'
msg += 'WHEN subfrac = {:.3f} and reltol = {:.4f}\n'.format(subfrac,
reltol)
msg += 'WHEN subfrac={:.3f}, rtol={:.4f}, seed={}\n'.format(subfrac,
reltol,
rn_seed)
it_sub = np.nditer(taxes_subsample, flags=['f_index'])
it_all = np.nditer(taxes_fullsample, flags=['f_index'])
while not it_sub.finished:
Expand All @@ -103,7 +104,7 @@ def test_agg(tests_path, puf_path): # pylint: disable=redefined-outer-name
tax_all = float(it_all[0])
reldiff = abs(tax_sub - tax_all) / abs(tax_all)
if reldiff > reltol:
msgstr = ' year,sub,full,reldif= {}\t{:.2f}\t{:.2f}\t{:.4f}\n'
msgstr = ' year,sub,full,reldiff= {}\t{:.2f}\t{:.2f}\t{:.4f}\n'
msg += msgstr.format(cyr, tax_sub, tax_all, reldiff)
it_sub.iternext()
it_all.iternext()
Expand Down

0 comments on commit cf86e07

Please sign in to comment.