From 4c88306e8bdf1521b445acf30fb8c95c76a3eaa5 Mon Sep 17 00:00:00 2001 From: "T.J. Alumbaugh" Date: Tue, 11 Jul 2017 21:39:36 -0700 Subject: [PATCH] Grab 'values' attribute from Series in iterate_jit - Resolves #1468 - if attribute passed to 'high level' function is a Series, grab the underlying 'values' attribute, which is enough like an ndarray to satisfy Numba (so that the jit operation can succeed) --- taxcalc/decorators.py | 9 ++++++++- taxcalc/tests/test_decorators.py | 17 +++++++++++++++-- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/taxcalc/decorators.py b/taxcalc/decorators.py index 342c78b6c..b86b8aac9 100644 --- a/taxcalc/decorators.py +++ b/taxcalc/decorators.py @@ -127,7 +127,14 @@ def hl_func(x_0, x_1, x_2, ...): fstr.write("):\n") fstr.write(" from pandas import DataFrame\n") fstr.write(" import numpy as np\n") + fstr.write(" import pandas as pd\n") + fstr.write(" def get_values(x):\n") + fstr.write(" if isinstance(x, pd.Series):\n") + fstr.write(" return x.values\n") + fstr.write(" else:\n") + fstr.write(" return x\n") fstr.write(" outputs = \\\n") + outs = [] for ppp, attr in zip(pm_or_pf, args_out + args_in): outs.append(ppp + "." + attr + ", ") @@ -135,7 +142,7 @@ def hl_func(x_0, x_1, x_2, ...): fstr.write(" (" + ", ".join(outs) + ") = \\\n") fstr.write(" " + "applied_f(") for ppp, attr in zip(pm_or_pf, args_out + args_in): - fstr.write(ppp + "." + attr + ", ") + fstr.write("get_values(" + ppp + "." + attr + ")" + ", ") fstr.write(")\n") fstr.write(" header = [") col_headers = ["'" + out + "'" for out in args_out] diff --git a/taxcalc/tests/test_decorators.py b/taxcalc/tests/test_decorators.py index 4eacad7d0..d9740c628 100644 --- a/taxcalc/tests/test_decorators.py +++ b/taxcalc/tests/test_decorators.py @@ -32,9 +32,16 @@ def test_create_toplevel_function_string_mult_outputs(): exp = ("def hl_func(pm, pf):\n" " from pandas import DataFrame\n" " import numpy as np\n" + " import pandas as pd\n" + " def get_values(x):\n" + " if isinstance(x, pd.Series):\n" + " return x.values\n" + " else:\n" + " return x\n" " outputs = \\\n" " (pm.a, pm.b) = \\\n" - " applied_f(pm.a, pm.b, pf.d, pm.e, )\n" + " applied_f(get_values(pm.a), get_values(pm.b), " + "get_values(pf.d), get_values(pm.e), )\n" " header = ['a', 'b']\n" " return DataFrame(data=np.column_stack(outputs)," "columns=header)") @@ -49,9 +56,15 @@ def test_create_toplevel_function_string(): exp = ("def hl_func(pm, pf):\n" " from pandas import DataFrame\n" " import numpy as np\n" + " import pandas as pd\n" + " def get_values(x):\n" + " if isinstance(x, pd.Series):\n" + " return x.values\n" + " else:\n" + " return x\n" " outputs = \\\n" " (pm.a) = \\\n" - " applied_f(pm.a, pf.d, pm.e, )\n" + " applied_f(get_values(pm.a), get_values(pf.d), get_values(pm.e), )\n" " header = ['a']\n" " return DataFrame(data=outputs," "columns=header)")