Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Parameterized.param.update to be used as a context manager for temporary updates #779

Merged
merged 12 commits into from
Jul 11, 2023
73 changes: 69 additions & 4 deletions examples/user_guide/Parameters.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -902,14 +902,79 @@
"id": "23979d82",
"metadata": {},
"source": [
"# Other Parameterized methods\n",
"## Other Parameterized methods\n",
"\n",
"Like `.param.pprint`, the remaining \"utility\" or convenience methods available for a `Parameterized` class or object are provided via a subobject called `param` that helps keep the namespace clean and disambiguate between Parameter objects and parameter values:\n",
"\n",
"- `.param.update(**kwargs)`: Set parameter values from the given `param=value` keyword arguments (or a dict or iterable), delaying watching and dependency handling until all have been updated. `.param.update` can also be used as a context manager to temporarily set values, that are restored to their original values when the context manager exits."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1779f544",
"metadata": {},
"outputs": [],
"source": [
"p.param.update(a=0, b='start');\n",
"print(p.a, p.b)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1508201",
"metadata": {},
"outputs": [],
"source": [
"with p.param.update(a=1, b='temp'):\n",
" print(f'In the context manager: {p.a=}, {p.b=}')\n",
"print(f'After the context manager exits: {p.a=}, {p.b=}')"
]
},
{
"cell_type": "markdown",
"id": "0b9e1d85",
"metadata": {},
"source": [
"- `.param.values(onlychanged=False)`: A dict of name,value pairs for all parameters of this object"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fd7f0eca",
"metadata": {},
"outputs": [],
"source": [
"p.param.values()"
]
},
{
"cell_type": "markdown",
"id": "244a17d0",
"metadata": {},
"source": [
"- `.param.objects(instance=True)`: Parameter objects of this instance or class"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1baf4823",
"metadata": {},
"outputs": [],
"source": [
"p.param.objects()"
]
},
{
"cell_type": "markdown",
"id": "1c5085ea",
"metadata": {},
"source": [
"\n",
"- `.param.add_parameter(param_name,param_obj)`: Dynamically add a new Parameter to this object's class\n",
"- `.param.update(**kwargs)`: Set parameter values from the given `param=value` keyword arguments (or a dict or iterable), delaying watching and dependency handling until all have been updated\n",
"- `.param.values(onlychanged=False)`: A dict of name,value pairs for all parameters of this object\n",
"- `.param.objects(instance=True)`: Parameter objects of this instance or class\n",
"- `.param.get_value_generator(name)`: Returns the underlying value-generating callable for this parameter, or the underlying static value if none\n",
"- `.param.force_new_dynamic_value(name)`: For a Dynamic parameter, generate a new value and return it\n",
"- `.param.inspect_value(name)`: For a Dynamic parameter, return the current value of the named attribute without modifying it.\n"
Expand Down
71 changes: 50 additions & 21 deletions param/parameterized.py
Original file line number Diff line number Diff line change
Expand Up @@ -1643,6 +1643,25 @@ def compare_mapping(cls, obj1, obj2):
return True


class _ParametersRestorer:
"""
Context-manager to handle the reset of parameter values after an update.
"""

def __init__(self, *, parameters, restore):
self._parameters = parameters
self._restore = restore

def __enter__(self):
return self._restore

def __exit__(self, exc_type, exc_value, exc_tb):
try:
self._parameters._update(self._restore)
finally:
self._restore = {}


class Parameters:
"""Object that holds the namespace and implementation of Parameterized
methods as well as any state that is not in __slots__ or the
Expand Down Expand Up @@ -1778,19 +1797,16 @@ def __getattr__(self_, attr):
else:
raise AttributeError(f"'{self_.cls.__name__}.param' object has no attribute {attr!r}")


@as_uninitialized
def _set_name(self_, name):
self = self_.param.self
self.name=name


@as_uninitialized
def _generate_name(self_):
self = self_.param.self
self.param._set_name('%s%05d' % (self.__class__.__name__ ,object_count))


@as_uninitialized
def _setup_params(self_,**params):
"""
Expand Down Expand Up @@ -2003,7 +2019,6 @@ def set_default(self_,param_name,value):
cls = self_.cls
setattr(cls,param_name,value)


def add_parameter(self_, param_name, param_obj):
"""
Add a new Parameter object into this object's class.
Expand Down Expand Up @@ -2056,46 +2071,61 @@ def params(self_, parameter_name=None):

def update(self_, *args, **kwargs):
"""
For the given dictionary or iterable or set of param=value keyword arguments,
sets the corresponding parameter of this object or class to the given value.
For the given dictionary or iterable or set of param=value
keyword arguments, sets the corresponding parameter of this
object or class to the given value.

May also be used as a context manager to temporarily set and
then reset parameter values.
"""
BATCH_WATCH = self_.self_or_cls.param._BATCH_WATCH
self_.self_or_cls.param._BATCH_WATCH = True
restore = self_._update(*args, **kwargs)
return _ParametersRestorer(parameters=self_, restore=restore)

def _update(self_, *args, **kwargs):
BATCH_WATCH = self_._BATCH_WATCH
self_._BATCH_WATCH = True
self_or_cls = self_.self_or_cls
if args:
if len(args) == 1 and not kwargs:
kwargs = args[0]
else:
self_.self_or_cls.param._BATCH_WATCH = False
raise ValueError("%s.update accepts *either* an iterable or key=value pairs, not both" %
(self_or_cls.name))
self_._BATCH_WATCH = False
raise ValueError(
f"{self_.cls.__name__}.param.update accepts *either* an iterable "
"or key=value pairs, not both."
)

trigger_params = [k for k in kwargs
if ((k in self_.self_or_cls.param) and
hasattr(self_.self_or_cls.param[k], '_autotrigger_value'))]
trigger_params = [
k for k in kwargs
if k in self_ and hasattr(self_[k], '_autotrigger_value')
]

for tp in trigger_params:
self_.self_or_cls.param[tp]._mode = 'set'

values = self_.values()
restore = {k: values[k] for k, v in kwargs.items() if k in values}

for (k, v) in kwargs.items():
if k not in self_or_cls.param:
self_.self_or_cls.param._BATCH_WATCH = False
raise ValueError(f"'{k}' is not a parameter of {self_or_cls.name}")
if k not in self_:
self_._BATCH_WATCH = False
raise ValueError(f"{k!r} is not a parameter of {self_.cls.__name__}")
try:
setattr(self_or_cls, k, v)
except:
self_.self_or_cls.param._BATCH_WATCH = False
self_._BATCH_WATCH = False
raise

self_.self_or_cls.param._BATCH_WATCH = BATCH_WATCH
self_._BATCH_WATCH = BATCH_WATCH
if not BATCH_WATCH:
self_._batch_call_watchers()

for tp in trigger_params:
p = self_.self_or_cls.param[tp]
p = self_[tp]
p._mode = 'reset'
setattr(self_or_cls, tp, p._autotrigger_reset_value)
p._mode = 'set-reset'
return restore

# PARAM3_DEPRECATION
@_deprecated(extra_msg="Use instead `.param.update`")
Expand All @@ -2121,7 +2151,6 @@ def set_param(self_, *args,**kwargs):
(self_or_cls.name))
return self_.update(kwargs)


def objects(self_, instance=True):
"""
Returns the Parameters of this instance or class
Expand Down
168 changes: 168 additions & 0 deletions tests/testparameterizedobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,174 @@ def test_error_if_non_param_in_constructor(self):
with pytest.raises(TypeError, match=re.escape(msg)):
TestPO(not_a_param=2)

def test_update_class(self):
class P(param.Parameterized):
x = param.Parameter()

p = P()

P.param.update(x=10)

assert P.x == p.x == 10

def test_update_context_class(self):
class P(param.Parameterized):
x = param.Parameter(10)

p = P()

with P.param.update(x=20):
assert P.x == p.x == 20

assert P.x == p.x == 10

def test_update_class_watcher(self):
class P(param.Parameterized):
x = param.Parameter()

events = []
P.param.watch(events.append, 'x')

P.param.update(x=10)

assert len(events) == 1
assert events[0].name == 'x' and events[0].new == 10

def test_update_context_class_watcher(self):
class P(param.Parameterized):
x = param.Parameter(0)

events = []
P.param.watch(events.append, 'x')

with P.param.update(x=20):
pass

assert len(events) == 2
assert events[0].name == 'x' and events[0].new == 20
assert events[1].name == 'x' and events[1].new == 0

def test_update_instance_watcher(self):
class P(param.Parameterized):
x = param.Parameter()

p = P()

events = []
p.param.watch(events.append, 'x')

p.param.update(x=10)

assert len(events) == 1
assert events[0].name == 'x' and events[0].new == 10

def test_update_context_instance_watcher(self):
class P(param.Parameterized):
x = param.Parameter(0)

p = P()

events = []
p.param.watch(events.append, 'x')

with p.param.update(x=20):
pass

assert len(events) == 2
assert events[0].name == 'x' and events[0].new == 20
assert events[1].name == 'x' and events[1].new == 0

def test_update_error_not_param_class(self):
with pytest.raises(ValueError, match="'not_a_param' is not a parameter of TestPO"):
TestPO.param.update(not_a_param=1)

def test_update_error_not_param_instance(self):
t = TestPO(inst='foo')
with pytest.raises(ValueError, match="'not_a_param' is not a parameter of TestPO"):
t.param.update(not_a_param=1)

def test_update_context_error_not_param_class(self):
with pytest.raises(ValueError, match="'not_a_param' is not a parameter of TestPO"):
with TestPO.param.update(not_a_param=1):
pass

def test_update_context_error_not_param_instance(self):
t = TestPO(inst='foo')
with pytest.raises(ValueError, match="'not_a_param' is not a parameter of TestPO"):
with t.param.update(not_a_param=1):
pass

def test_update_error_while_updating(self):
class P(param.Parameterized):
x = param.Parameter(0, readonly=True)

with pytest.raises(TypeError):
P.param.update(x=1)

assert P.x == 0

with pytest.raises(TypeError):
with P.param.update(x=1):
pass

assert P.x == 0

p = P()

with pytest.raises(TypeError):
p.param.update(x=1)

assert p.x == 0

with pytest.raises(TypeError):
with p.param.update(x=1):
pass

assert p.x == 0

def test_update_error_dict_and_kwargs_instance(self):
t = TestPO(inst='foo')
with pytest.raises(ValueError, match=re.escape("TestPO.param.update accepts *either* an iterable or key=value pairs, not both")):
t.param.update(dict(a=1), a=1)

def test_update_context_error_dict_and_kwargs_instance(self):
t = TestPO(inst='foo')
with pytest.raises(ValueError, match=re.escape("TestPO.param.update accepts *either* an iterable or key=value pairs, not both")):
with t.param.update(dict(a=1), a=1):
pass

def test_update_error_dict_and_kwargs_class(self):
with pytest.raises(ValueError, match=re.escape("TestPO.param.update accepts *either* an iterable or key=value pairs, not both")):
TestPO.param.update(dict(a=1), a=1)

def test_update_context_error_dict_and_kwargs_class(self):
with pytest.raises(ValueError, match=re.escape("TestPO.param.update accepts *either* an iterable or key=value pairs, not both")):
with TestPO.param.update(dict(a=1), a=1):
pass

def test_update_context_single_parameter(self):
t = TestPO(inst='foo')
with t.param.update(inst='bar'):
assert t.inst == 'bar'
assert t.inst == 'foo'

def test_update_context_does_not_set_other_params(self):
t = TestPO(inst='foo')
events = []
t.param.watch(events.append, list(t.param), onlychanged=False)
with t.param.update(inst='bar'):
pass
assert len(events) == 2
assert all(e.name == 'inst' for e in events)

def test_update_context_multi_parameter(self):
t = TestPO(inst='foo', notinst=1)
with t.param.update(inst='bar', notinst=2):
assert t.inst == 'bar'
assert t.notinst == 2
assert t.inst == 'foo'
assert t.notinst == 1


class some_fn(param.ParameterizedFunction):
__test__ = False
Expand Down