Skip to content

Commit

Permalink
Grab 'values' attribute from Series in iterate_jit
Browse files Browse the repository at this point in the history
 - Resolves PSLmodels#1468
 - if attribute passed to 'high level' function is a Series,
   grab the underlying 'values' attribute, which is enough like an
   ndarray to satisfy Numba (so that the jit operation can succeed)
  • Loading branch information
T.J. Alumbaugh committed Jul 12, 2017
1 parent eab6cfa commit 4c88306
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
9 changes: 8 additions & 1 deletion taxcalc/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,22 @@ def hl_func(x_0, x_1, x_2, ...):
fstr.write("):\n")
fstr.write(" from pandas import DataFrame\n")
fstr.write(" import numpy as np\n")
fstr.write(" import pandas as pd\n")
fstr.write(" def get_values(x):\n")
fstr.write(" if isinstance(x, pd.Series):\n")
fstr.write(" return x.values\n")
fstr.write(" else:\n")
fstr.write(" return x\n")
fstr.write(" outputs = \\\n")

outs = []
for ppp, attr in zip(pm_or_pf, args_out + args_in):
outs.append(ppp + "." + attr + ", ")
outs = [m_or_f + "." + arg for m_or_f, arg in zip(pm_or_pf, args_out)]
fstr.write(" (" + ", ".join(outs) + ") = \\\n")
fstr.write(" " + "applied_f(")
for ppp, attr in zip(pm_or_pf, args_out + args_in):
fstr.write(ppp + "." + attr + ", ")
fstr.write("get_values(" + ppp + "." + attr + ")" + ", ")
fstr.write(")\n")
fstr.write(" header = [")
col_headers = ["'" + out + "'" for out in args_out]
Expand Down
17 changes: 15 additions & 2 deletions taxcalc/tests/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,16 @@ def test_create_toplevel_function_string_mult_outputs():
exp = ("def hl_func(pm, pf):\n"
" from pandas import DataFrame\n"
" import numpy as np\n"
" import pandas as pd\n"
" def get_values(x):\n"
" if isinstance(x, pd.Series):\n"
" return x.values\n"
" else:\n"
" return x\n"
" outputs = \\\n"
" (pm.a, pm.b) = \\\n"
" applied_f(pm.a, pm.b, pf.d, pm.e, )\n"
" applied_f(get_values(pm.a), get_values(pm.b), "
"get_values(pf.d), get_values(pm.e), )\n"
" header = ['a', 'b']\n"
" return DataFrame(data=np.column_stack(outputs),"
"columns=header)")
Expand All @@ -49,9 +56,15 @@ def test_create_toplevel_function_string():
exp = ("def hl_func(pm, pf):\n"
" from pandas import DataFrame\n"
" import numpy as np\n"
" import pandas as pd\n"
" def get_values(x):\n"
" if isinstance(x, pd.Series):\n"
" return x.values\n"
" else:\n"
" return x\n"
" outputs = \\\n"
" (pm.a) = \\\n"
" applied_f(pm.a, pf.d, pm.e, )\n"
" applied_f(get_values(pm.a), get_values(pf.d), get_values(pm.e), )\n"
" header = ['a']\n"
" return DataFrame(data=outputs,"
"columns=header)")
Expand Down

0 comments on commit 4c88306

Please sign in to comment.