Skip to content

Commit

Permalink
Add Policy.scan_param_code function to check for dangerous code.
Browse files Browse the repository at this point in the history
  • Loading branch information
martinholmer committed Dec 2, 2016
1 parent 435d222 commit e545fa9
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 8 deletions.
8 changes: 4 additions & 4 deletions taxcalc/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,13 @@ def ALD_Investment_ec_base_code_function(calc):
"""
code = calc.policy.param_code['ALD_Investment_ec_base_code']
visible = {'min': np.minimum, 'max': np.maximum, 'where': np.where}
vars = ['e00300', 'e00600', 'e00650', 'e01100', 'e01200',
'p22250', 'p23250', '_sep']
for var in vars:
variables = ['e00300', 'e00600', 'e00650', 'e01100', 'e01200',
'p22250', 'p23250', '_sep']
for var in variables:
visible[var] = getattr(calc.records, var)
# pylint: disable=eval-used
calc.records.investment_ec_base = eval(compile(code, '<str>', 'eval'),
{'__builtins__': None}, visible)
{'__builtins__': {}}, visible)


@iterate_jit(nopython=True)
Expand Down
23 changes: 21 additions & 2 deletions taxcalc/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,9 @@ def implement_reform(self, reform):
param_code_dict = reform.pop(zero, None)
if param_code_dict:
reform_years.remove(zero)
for param in param_code_dict.keys():
self.param_code[param] = param_code_dict[param]
for param, code in param_code_dict.items():
Policy.scan_param_code(code)
self.param_code[param] = code
# check range of remaining reform_years
first_reform_year = min(reform_years)
if first_reform_year < self.start_year:
Expand All @@ -346,6 +347,24 @@ def implement_reform(self, reform):
self._update({year: reform[year]})
self.set_year(precall_current_year)

@staticmethod
def scan_param_code(code):
"""
Raise ValueError if certain character strings found in specified code.
"""
if re.search(r'__', code) is not None:
msg = 'Following param_code includes illegal "__":\n'
msg += code
raise ValueError(msg)
if re.search(r'lambda', code) is not None:
msg = 'Following param_code includes illegal "lambda":\n'
msg += code
raise ValueError(msg)
if re.search(r'\[', code) is not None:
msg = 'Following param_code includes illegal "[":\n'
msg += code
raise ValueError(msg)

@staticmethod
def convert_reform_dictionary(param_key_dict):
"""
Expand Down
16 changes: 14 additions & 2 deletions taxcalc/tests/test_policy.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
import sys
import json
import tempfile
import numpy as np
from numpy.testing import assert_allclose
import pytest
import tempfile
from taxcalc import Policy


Expand Down Expand Up @@ -725,6 +725,18 @@ def test_current_law_version():
mte = clv._SS_Earnings_c
clv_mte_2015 = mte[2015 - syr]
clv_mte_2016 = mte[2016 - syr]
assert (clp_mte_2015 == ref_mte_2015 == clv_mte_2015)
assert clp_mte_2015 == ref_mte_2015 == clv_mte_2015
assert clp_mte_2016 != ref_mte_2016
assert clp_mte_2016 == clv_mte_2016


def test_scan_param_code():
"""
Test scan_param_code function.
"""
with pytest.raises(ValueError):
Policy.scan_param_code('__builtins__')
with pytest.raises(ValueError):
Policy.scan_param_code('lambda x: x**2')
with pytest.raises(ValueError):
Policy.scan_param_code('[x**2 for x in range(9)]')

0 comments on commit e545fa9

Please sign in to comment.