Skip to content

Commit

Permalink
Allow Parameterized.param.update to be used as a context manager fo…
Browse files Browse the repository at this point in the history
…r temporary updates (#779)

Co-authored-by: maximlt <mliquet@anaconda.com>
  • Loading branch information
philippjfr and maximlt authored Jul 11, 2023
1 parent 6692574 commit 6b50168
Show file tree
Hide file tree
Showing 3 changed files with 287 additions and 25 deletions.
73 changes: 69 additions & 4 deletions examples/user_guide/Parameters.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -902,14 +902,79 @@
"id": "23979d82",
"metadata": {},
"source": [
"# Other Parameterized methods\n",
"## Other Parameterized methods\n",
"\n",
"Like `.param.pprint`, the remaining \"utility\" or convenience methods available for a `Parameterized` class or object are provided via a subobject called `param` that helps keep the namespace clean and disambiguate between Parameter objects and parameter values:\n",
"\n",
"- `.param.update(**kwargs)`: Set parameter values from the given `param=value` keyword arguments (or a dict or iterable), delaying watching and dependency handling until all have been updated. `.param.update` can also be used as a context manager to temporarily set values, that are restored to their original values when the context manager exits."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1779f544",
"metadata": {},
"outputs": [],
"source": [
"p.param.update(a=0, b='start');\n",
"print(p.a, p.b)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1508201",
"metadata": {},
"outputs": [],
"source": [
"with p.param.update(a=1, b='temp'):\n",
" print(f'In the context manager: {p.a=}, {p.b=}')\n",
"print(f'After the context manager exits: {p.a=}, {p.b=}')"
]
},
{
"cell_type": "markdown",
"id": "0b9e1d85",
"metadata": {},
"source": [
"- `.param.values(onlychanged=False)`: A dict of name,value pairs for all parameters of this object"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fd7f0eca",
"metadata": {},
"outputs": [],
"source": [
"p.param.values()"
]
},
{
"cell_type": "markdown",
"id": "244a17d0",
"metadata": {},
"source": [
"- `.param.objects(instance=True)`: Parameter objects of this instance or class"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1baf4823",
"metadata": {},
"outputs": [],
"source": [
"p.param.objects()"
]
},
{
"cell_type": "markdown",
"id": "1c5085ea",
"metadata": {},
"source": [
"\n",
"- `.param.add_parameter(param_name,param_obj)`: Dynamically add a new Parameter to this object's class\n",
"- `.param.update(**kwargs)`: Set parameter values from the given `param=value` keyword arguments (or a dict or iterable), delaying watching and dependency handling until all have been updated\n",
"- `.param.values(onlychanged=False)`: A dict of name,value pairs for all parameters of this object\n",
"- `.param.objects(instance=True)`: Parameter objects of this instance or class\n",
"- `.param.get_value_generator(name)`: Returns the underlying value-generating callable for this parameter, or the underlying static value if none\n",
"- `.param.force_new_dynamic_value(name)`: For a Dynamic parameter, generate a new value and return it\n",
"- `.param.inspect_value(name)`: For a Dynamic parameter, return the current value of the named attribute without modifying it.\n"
Expand Down
71 changes: 50 additions & 21 deletions param/parameterized.py
Original file line number Diff line number Diff line change
Expand Up @@ -1643,6 +1643,25 @@ def compare_mapping(cls, obj1, obj2):
return True


class _ParametersRestorer:
"""
Context-manager to handle the reset of parameter values after an update.
"""

def __init__(self, *, parameters, restore):
self._parameters = parameters
self._restore = restore

def __enter__(self):
return self._restore

def __exit__(self, exc_type, exc_value, exc_tb):
try:
self._parameters._update(self._restore)
finally:
self._restore = {}


class Parameters:
"""Object that holds the namespace and implementation of Parameterized
methods as well as any state that is not in __slots__ or the
Expand Down Expand Up @@ -1778,19 +1797,16 @@ def __getattr__(self_, attr):
else:
raise AttributeError(f"'{self_.cls.__name__}.param' object has no attribute {attr!r}")


@as_uninitialized
def _set_name(self_, name):
self = self_.param.self
self.name=name


@as_uninitialized
def _generate_name(self_):
self = self_.param.self
self.param._set_name('%s%05d' % (self.__class__.__name__ ,object_count))


@as_uninitialized
def _setup_params(self_,**params):
"""
Expand Down Expand Up @@ -2003,7 +2019,6 @@ def set_default(self_,param_name,value):
cls = self_.cls
setattr(cls,param_name,value)


def add_parameter(self_, param_name, param_obj):
"""
Add a new Parameter object into this object's class.
Expand Down Expand Up @@ -2056,46 +2071,61 @@ def params(self_, parameter_name=None):

def update(self_, *args, **kwargs):
"""
For the given dictionary or iterable or set of param=value keyword arguments,
sets the corresponding parameter of this object or class to the given value.
For the given dictionary or iterable or set of param=value
keyword arguments, sets the corresponding parameter of this
object or class to the given value.
May also be used as a context manager to temporarily set and
then reset parameter values.
"""
BATCH_WATCH = self_.self_or_cls.param._BATCH_WATCH
self_.self_or_cls.param._BATCH_WATCH = True
restore = self_._update(*args, **kwargs)
return _ParametersRestorer(parameters=self_, restore=restore)

def _update(self_, *args, **kwargs):
BATCH_WATCH = self_._BATCH_WATCH
self_._BATCH_WATCH = True
self_or_cls = self_.self_or_cls
if args:
if len(args) == 1 and not kwargs:
kwargs = args[0]
else:
self_.self_or_cls.param._BATCH_WATCH = False
raise ValueError("%s.update accepts *either* an iterable or key=value pairs, not both" %
(self_or_cls.name))
self_._BATCH_WATCH = False
raise ValueError(
f"{self_.cls.__name__}.param.update accepts *either* an iterable "
"or key=value pairs, not both."
)

trigger_params = [k for k in kwargs
if ((k in self_.self_or_cls.param) and
hasattr(self_.self_or_cls.param[k], '_autotrigger_value'))]
trigger_params = [
k for k in kwargs
if k in self_ and hasattr(self_[k], '_autotrigger_value')
]

for tp in trigger_params:
self_.self_or_cls.param[tp]._mode = 'set'

values = self_.values()
restore = {k: values[k] for k, v in kwargs.items() if k in values}

for (k, v) in kwargs.items():
if k not in self_or_cls.param:
self_.self_or_cls.param._BATCH_WATCH = False
raise ValueError(f"'{k}' is not a parameter of {self_or_cls.name}")
if k not in self_:
self_._BATCH_WATCH = False
raise ValueError(f"{k!r} is not a parameter of {self_.cls.__name__}")
try:
setattr(self_or_cls, k, v)
except:
self_.self_or_cls.param._BATCH_WATCH = False
self_._BATCH_WATCH = False
raise

self_.self_or_cls.param._BATCH_WATCH = BATCH_WATCH
self_._BATCH_WATCH = BATCH_WATCH
if not BATCH_WATCH:
self_._batch_call_watchers()

for tp in trigger_params:
p = self_.self_or_cls.param[tp]
p = self_[tp]
p._mode = 'reset'
setattr(self_or_cls, tp, p._autotrigger_reset_value)
p._mode = 'set-reset'
return restore

# PARAM3_DEPRECATION
@_deprecated(extra_msg="Use instead `.param.update`")
Expand All @@ -2121,7 +2151,6 @@ def set_param(self_, *args,**kwargs):
(self_or_cls.name))
return self_.update(kwargs)


def objects(self_, instance=True):
"""
Returns the Parameters of this instance or class
Expand Down
168 changes: 168 additions & 0 deletions tests/testparameterizedobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,174 @@ def test_error_if_non_param_in_constructor(self):
with pytest.raises(TypeError, match=re.escape(msg)):
TestPO(not_a_param=2)

def test_update_class(self):
class P(param.Parameterized):
x = param.Parameter()

p = P()

P.param.update(x=10)

assert P.x == p.x == 10

def test_update_context_class(self):
class P(param.Parameterized):
x = param.Parameter(10)

p = P()

with P.param.update(x=20):
assert P.x == p.x == 20

assert P.x == p.x == 10

def test_update_class_watcher(self):
class P(param.Parameterized):
x = param.Parameter()

events = []
P.param.watch(events.append, 'x')

P.param.update(x=10)

assert len(events) == 1
assert events[0].name == 'x' and events[0].new == 10

def test_update_context_class_watcher(self):
class P(param.Parameterized):
x = param.Parameter(0)

events = []
P.param.watch(events.append, 'x')

with P.param.update(x=20):
pass

assert len(events) == 2
assert events[0].name == 'x' and events[0].new == 20
assert events[1].name == 'x' and events[1].new == 0

def test_update_instance_watcher(self):
class P(param.Parameterized):
x = param.Parameter()

p = P()

events = []
p.param.watch(events.append, 'x')

p.param.update(x=10)

assert len(events) == 1
assert events[0].name == 'x' and events[0].new == 10

def test_update_context_instance_watcher(self):
class P(param.Parameterized):
x = param.Parameter(0)

p = P()

events = []
p.param.watch(events.append, 'x')

with p.param.update(x=20):
pass

assert len(events) == 2
assert events[0].name == 'x' and events[0].new == 20
assert events[1].name == 'x' and events[1].new == 0

def test_update_error_not_param_class(self):
with pytest.raises(ValueError, match="'not_a_param' is not a parameter of TestPO"):
TestPO.param.update(not_a_param=1)

def test_update_error_not_param_instance(self):
t = TestPO(inst='foo')
with pytest.raises(ValueError, match="'not_a_param' is not a parameter of TestPO"):
t.param.update(not_a_param=1)

def test_update_context_error_not_param_class(self):
with pytest.raises(ValueError, match="'not_a_param' is not a parameter of TestPO"):
with TestPO.param.update(not_a_param=1):
pass

def test_update_context_error_not_param_instance(self):
t = TestPO(inst='foo')
with pytest.raises(ValueError, match="'not_a_param' is not a parameter of TestPO"):
with t.param.update(not_a_param=1):
pass

def test_update_error_while_updating(self):
class P(param.Parameterized):
x = param.Parameter(0, readonly=True)

with pytest.raises(TypeError):
P.param.update(x=1)

assert P.x == 0

with pytest.raises(TypeError):
with P.param.update(x=1):
pass

assert P.x == 0

p = P()

with pytest.raises(TypeError):
p.param.update(x=1)

assert p.x == 0

with pytest.raises(TypeError):
with p.param.update(x=1):
pass

assert p.x == 0

def test_update_error_dict_and_kwargs_instance(self):
t = TestPO(inst='foo')
with pytest.raises(ValueError, match=re.escape("TestPO.param.update accepts *either* an iterable or key=value pairs, not both")):
t.param.update(dict(a=1), a=1)

def test_update_context_error_dict_and_kwargs_instance(self):
t = TestPO(inst='foo')
with pytest.raises(ValueError, match=re.escape("TestPO.param.update accepts *either* an iterable or key=value pairs, not both")):
with t.param.update(dict(a=1), a=1):
pass

def test_update_error_dict_and_kwargs_class(self):
with pytest.raises(ValueError, match=re.escape("TestPO.param.update accepts *either* an iterable or key=value pairs, not both")):
TestPO.param.update(dict(a=1), a=1)

def test_update_context_error_dict_and_kwargs_class(self):
with pytest.raises(ValueError, match=re.escape("TestPO.param.update accepts *either* an iterable or key=value pairs, not both")):
with TestPO.param.update(dict(a=1), a=1):
pass

def test_update_context_single_parameter(self):
t = TestPO(inst='foo')
with t.param.update(inst='bar'):
assert t.inst == 'bar'
assert t.inst == 'foo'

def test_update_context_does_not_set_other_params(self):
t = TestPO(inst='foo')
events = []
t.param.watch(events.append, list(t.param), onlychanged=False)
with t.param.update(inst='bar'):
pass
assert len(events) == 2
assert all(e.name == 'inst' for e in events)

def test_update_context_multi_parameter(self):
t = TestPO(inst='foo', notinst=1)
with t.param.update(inst='bar', notinst=2):
assert t.inst == 'bar'
assert t.notinst == 2
assert t.inst == 'foo'
assert t.notinst == 1


class some_fn(param.ParameterizedFunction):
__test__ = False
Expand Down

0 comments on commit 6b50168

Please sign in to comment.