Skip to content

Commit

Permalink
Merge pull request #2809 from martinholmer/revise-tmd-ctor
Browse files Browse the repository at this point in the history
Generalize Records.tmd_constructor static method
  • Loading branch information
martinholmer authored Sep 24, 2024
2 parents 8504891 + b0d0358 commit 0936659
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions taxcalc/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# pylint --disable=locally-disabled records.py

import os
from pathlib import Path
import numpy as np
import pandas as pd
from taxcalc.data import Data
Expand Down Expand Up @@ -116,9 +117,6 @@ class instance: Records
PUF_RATIOS_FILENAME = 'puf_ratios.csv'
CPS_WEIGHTS_FILENAME = 'cps_weights.csv.gz'
CPS_RATIOS_FILENAME = None
TMD_WEIGHTS_FILENAME = 'tmd_weights.csv.gz'
TMD_GROWFACTORS_FILENAME = 'tmd_growfactors.csv'
TMD_RATIOS_FILENAME = None
CODE_PATH = os.path.abspath(os.path.dirname(__file__))
VARINFO_FILE_NAME = 'records_variables.json'
VARINFO_FILE_PATH = CODE_PATH
Expand Down Expand Up @@ -226,9 +224,12 @@ def cps_constructor(data=None,
exact_calculations=exact_calculations)

@staticmethod
def tmd_constructor(data, # path to tmd.csv file or dataframe
gfactors=GrowFactors(TMD_GROWFACTORS_FILENAME),
exact_calculations=False): # pragma: no cover
def tmd_constructor(
data_path: Path,
weights_path: Path,
growfactors_path: Path,
exact_calculations=False
): # pragma: no cover
"""
Static method returns a Records object instantiated with TMD
input data. This works in a analogous way to Records(), which
Expand All @@ -239,14 +240,15 @@ def tmd_constructor(data, # path to tmd.csv file or dataframe
eliminate the need to specify all the details of the PUF input
data.
"""
weights = os.path.join(Records.CODE_PATH, Records.TMD_WEIGHTS_FILENAME)
return Records(data=data,
start_year=Records.TMDCSV_YEAR,
gfactors=gfactors,
weights=weights,
adjust_ratios=Records.TMD_RATIOS_FILENAME,
exact_calculations=exact_calculations)

return Records(
data=pd.read_csv(data_path),
start_year=Records.TMDCSV_YEAR,
weights=str(weights_path),
gfactors=GrowFactors(growfactors_filename=str(growfactors_path)),
adjust_ratios=None,
exact_calculations=exact_calculations,
)

def increment_year(self):
"""
Add one to current year, and also does
Expand Down Expand Up @@ -277,7 +279,7 @@ def _extrapolate(self, year):
"""
# pylint: disable=too-many-statements,no-member
# put values in local dictionary
gfv = dict()
gfv = {}
for name in GrowFactors.VALID_NAMES:
gfv[name] = self.gfactors.factor_value(name, year)
# apply values to Records variables
Expand Down

0 comments on commit 0936659

Please sign in to comment.