Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MRI perturbation interface #358

Merged
merged 4 commits into from
May 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(
allow_selection: bool = False,
cell_size: int = 24,
per_page: int = 20,
page: int = 0,
):
"""Gallery view of a DataFrame.

Expand Down Expand Up @@ -51,6 +52,7 @@ def __init__(
allow_selection=allow_selection,
cell_size=cell_size,
per_page=per_page,
page=page,
)

def _get_ipython_height(self):
Expand Down
23 changes: 16 additions & 7 deletions meerkat/interactive/graph/reactivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def reactive(
fn: Callable = None,
nested_return: bool = False,
skip_fn: Callable[..., bool] = None,
backend_only: bool = False,
) -> Callable:
"""Internal decorator that is used to mark a function as reactive.
This is only meant for internal use, and users should use the
Expand Down Expand Up @@ -88,7 +89,12 @@ def concat(df1: mk.DataFrame, df2: mk.DataFrame) -> mk.DataFrame:
if fn is None:
# need to make passing args to the args optional
# note: all of the args passed to the decorator MUST be optional
return partial(reactive, nested_return=nested_return, skip_fn=skip_fn)
return partial(
reactive,
nested_return=nested_return,
skip_fn=skip_fn,
backend_only=backend_only,
)

# Built-in functions cannot be wrapped in reactive.
# They have to be converted to a lambda function first and then run.
Expand Down Expand Up @@ -122,6 +128,7 @@ def wrapper(*args, **kwargs):
# Then, Store(list)[0] should also return a Store.
# TODO (arjun): These if this assumption holds.
nonlocal nested_return
nonlocal backend_only
nonlocal fn

_is_unmarked_context = is_unmarked_context()
Expand Down Expand Up @@ -221,13 +228,15 @@ def _fn_wrapper(*args, **kwargs):

# Wrap the Result in NodeMixin objects
if nested_return:
result = _nested_apply(result, fn=_wrap_outputs)
result = _nested_apply(
result, fn=partial(_wrap_outputs, backend_only=backend_only)
)
elif isinstance(result, NodeMixin):
result = result
elif isinstance(result, Iterator):
result = _IteratorStore(result)
result = _IteratorStore(result, backend_only=backend_only)
else:
result = Store(result)
result = Store(result, backend_only=backend_only)

# If the object is a ReactifiableMixin, we should turn
# reactivity on.
Expand Down Expand Up @@ -336,14 +345,14 @@ def _add_op_as_child(op: Operation, *nodeables: NodeMixin):
nodeable.inode.add_child(op.inode, triggers=triggers)


def _wrap_outputs(obj):
def _wrap_outputs(obj, backend_only=False):
from meerkat.interactive.graph.store import Store, _IteratorStore

if isinstance(obj, NodeMixin):
return obj
elif isinstance(obj, Iterator):
return _IteratorStore(obj)
return Store(obj)
return _IteratorStore(obj, backend_only=backend_only)
return Store(obj, backend_only=backend_only)


def _create_nodes_for_nodeables(*nodeables: NodeMixin):
Expand Down