Skip to content

Commit

Permalink
migrate base operator to operator and base_runtime to runtime (#359)
Browse files Browse the repository at this point in the history
* migrate base operator to operator

* add baseoperator alias

* fix init file alias import and remove alias from operator module
  • Loading branch information
jperez999 authored Jul 12, 2023
1 parent 6cf6a02 commit 2d68307
Show file tree
Hide file tree
Showing 21 changed files with 106 additions and 62 deletions.
3 changes: 2 additions & 1 deletion docs/source/api/merlin.dag.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ Merlin DAG
.. autosummary::
:toctree: generated

merlin.dag.BaseOperator
merlin.dag.Operator

merlin.dag.Graph
merlin.dag.Node
merlin.dag.ColumnSelector
6 changes: 4 additions & 2 deletions merlin/dag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

# flake8: noqa
from merlin.dag.base_operator import BaseOperator, DataFormats, Supports

from merlin.dag.graph import Graph
from merlin.dag.node import Node, iter_nodes, postorder_iter_nodes, preorder_iter_nodes
from merlin.dag.operator import DataFormats, Operator, Supports
from merlin.dag.selector import ColumnSelector
from merlin.dag.utils import group_values_offsets, ungroup_values_offsets

BaseOperator = Operator
4 changes: 2 additions & 2 deletions merlin/dag/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
from collections import deque
from typing import Dict

from merlin.dag.base_operator import BaseOperator
from merlin.dag.node import (
Node,
_combine_schemas,
iter_nodes,
postorder_iter_nodes,
preorder_iter_nodes,
)
from merlin.dag.operator import Operator
from merlin.dag.ops.stat_operator import StatOperator
from merlin.schema import Schema

Expand All @@ -39,7 +39,7 @@ class Graph:
"""

def __init__(self, output_node: Node):
if isinstance(output_node, BaseOperator):
if isinstance(output_node, Operator):
output_node = Node.construct_from(output_node)

self.output_node = output_node
Expand Down
30 changes: 16 additions & 14 deletions merlin/dag/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@
import collections.abc
from typing import List, Union

from merlin.dag.base_operator import BaseOperator
from merlin.dag.operator import Operator
from merlin.dag.ops import ConcatColumns, GroupingOp, SelectionOp, SubsetColumns, SubtractionOp
from merlin.dag.ops.udf import UDF
from merlin.dag.selector import ColumnSelector
from merlin.schema import Schema

Nodable = Union[
"Node",
BaseOperator,
Operator,
str,
List[str],
ColumnSelector,
List[Union["Node", BaseOperator, str, List[str], ColumnSelector]],
List[Union["Node", Operator, str, List[str], ColumnSelector]],
]


Expand Down Expand Up @@ -253,28 +253,30 @@ def validate_schemas(self, root_schema: Schema, strict_dtypes: bool = False):
)

def __rshift__(self, operator):
"""Transforms this Node by applying an BaseOperator
"""Transforms this Node by applying an Operator
Parameters
-----------
operators: BaseOperator or callable
Returns
-------
Node
Parameters
-----------
operators: Operator
or callable
Returns
-------
Node
"""

if callable(operator) and not (
isinstance(operator, type) and issubclass(operator, BaseOperator)
isinstance(operator, type) and issubclass(operator, Operator)
):
# implicit lambdaop conversion.
operator = UDF(operator)

if isinstance(operator, type) and issubclass(operator, BaseOperator):
if isinstance(operator, type) and issubclass(operator, Operator):
# handle case where an operator class is passed
operator = operator()

if not isinstance(operator, BaseOperator):
if not isinstance(operator, Operator):
raise ValueError(f"Expected operator or callable, got {operator.__class__}")

child = type(self)()
Expand Down Expand Up @@ -553,7 +555,7 @@ def construct_from(
return Node(ColumnSelector([nodable]))
if isinstance(nodable, ColumnSelector):
return Node(nodable)
elif isinstance(nodable, BaseOperator):
elif isinstance(nodable, Operator):
node = Node()
node.op = nodable
return node
Expand Down
41 changes: 40 additions & 1 deletion merlin/dag/base_operator.py → merlin/dag/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import Any, List, Optional, Union

import merlin.dag
import merlin.dag.utils
from merlin.core.protocols import Transformable
from merlin.dag.selector import ColumnSelector
from merlin.schema import ColumnSchema, Schema
Expand Down Expand Up @@ -55,7 +56,8 @@ class DataFormats(Flag):
CUPY_DICT_ARRAY = auto()


class BaseOperator:
# pylint: disable=too-many-public-methods
class Operator:
"""
Base class for all operator classes.
"""
Expand Down Expand Up @@ -409,3 +411,40 @@ def _get_columns(self, df, selector):
return {col_name: df[col_name] for col_name in selector.names}
else:
return df[selector.names]

@property
def export_name(self):
"""
Provides a clear common english identifier for this operator.
Returns
-------
String
Name of the current class as spelled in module.
"""
return self.__class__.__name__.lower()

def export(self, path: str, input_schema: Schema, output_schema: Schema, **kwargs):
"""
Export the class object as a config and all related files to the user defined path.
Parameters
----------
path : str
Artifact export path
input_schema : Schema
A schema with information about the inputs to this operator.
output_schema : Schema
A schema with information about the outputs of this operator.
params : dict, optional
Parameters dictionary of key, value pairs stored in exported config, by default None.
node_id : int, optional
The placement of the node in the graph (starts at 1), by default None.
version : int, optional
The version of the operator, by default 1.
Returns
-------
model_config: dict
The config for the exported operator.
"""
4 changes: 2 additions & 2 deletions merlin/dag/ops/add_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from merlin.dag.base_operator import BaseOperator
from merlin.dag.operator import Operator
from merlin.schema.tags import Tags


class AddMetadata(BaseOperator):
class AddMetadata(Operator):
"""
This operator will add user defined tags and properties
to a Schema.
Expand Down
4 changes: 2 additions & 2 deletions merlin/dag/ops/concat_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
#

from merlin.core.protocols import Transformable
from merlin.dag.base_operator import BaseOperator
from merlin.dag.operator import Operator
from merlin.dag.selector import ColumnSelector
from merlin.schema import Schema


class ConcatColumns(BaseOperator):
class ConcatColumns(Operator):
"""
This operator class provides an implementation for the `+` operator used in constructing graphs.
"""
Expand Down
6 changes: 3 additions & 3 deletions merlin/dag/ops/rename.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
# limitations under the License.
#
from merlin.core.protocols import Transformable
from merlin.dag.base_operator import BaseOperator
from merlin.dag.operator import Operator
from merlin.dag.selector import ColumnSelector


class Rename(BaseOperator):
class Rename(Operator):
"""This operation renames columns by one of several methods:
- using a user defined lambda function to transform column names
Expand Down Expand Up @@ -59,7 +59,7 @@ def transform(
)
return transformable

transform.__doc__ = BaseOperator.transform.__doc__
transform.__doc__ = Operator.transform.__doc__

def column_mapping(self, col_selector):
column_mapping = {}
Expand Down
4 changes: 2 additions & 2 deletions merlin/dag/ops/selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
import logging

from merlin.core.protocols import Transformable
from merlin.dag.base_operator import BaseOperator
from merlin.dag.operator import Operator
from merlin.dag.selector import ColumnSelector
from merlin.schema import Schema

LOG = logging.getLogger("SelectionOp")


class SelectionOp(BaseOperator):
class SelectionOp(Operator):
"""
This operator class provides an implementation of the behavior of selection (e.g. input) nodes.
"""
Expand Down
4 changes: 2 additions & 2 deletions merlin/dag/ops/stat_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@

import dask.dataframe as dd

from merlin.dag.base_operator import BaseOperator
from merlin.dag.operator import Operator
from merlin.dag.selector import ColumnSelector


class StatOperator(BaseOperator):
class StatOperator(Operator):
"""
Base class for statistical operator classes. This adds a 'fit' and 'finalize' method
on top of the Operator class.
Expand Down
4 changes: 2 additions & 2 deletions merlin/dag/ops/subset_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
#

from merlin.core.protocols import Transformable
from merlin.dag.base_operator import BaseOperator
from merlin.dag.operator import Operator
from merlin.dag.selector import ColumnSelector


class SubsetColumns(BaseOperator):
class SubsetColumns(Operator):
"""
This operator class provides an implementation for the `[]` operator
used in constructing graphs.
Expand Down
4 changes: 2 additions & 2 deletions merlin/dag/ops/subtraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
from __future__ import annotations

from merlin.core.protocols import Transformable
from merlin.dag.base_operator import BaseOperator
from merlin.dag.operator import Operator
from merlin.dag.selector import ColumnSelector
from merlin.schema import Schema


class SubtractionOp(BaseOperator):
class SubtractionOp(Operator):
"""
This operator class provides an implementation for the `-` operator used in constructing graphs.
"""
Expand Down
6 changes: 3 additions & 3 deletions merlin/dag/ops/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@

from merlin.core.dispatch import make_df
from merlin.core.protocols import Transformable
from merlin.dag.base_operator import BaseOperator
from merlin.dag.operator import Operator
from merlin.dag.selector import ColumnSelector


class UDF(BaseOperator):
class UDF(Operator):
"""
UDF allows you to apply row level functions to a dataframe or TensorTable
Expand Down Expand Up @@ -82,7 +82,7 @@ def transform(
# return input type data
return make_df(new_df)

transform.__doc__ = BaseOperator.transform.__doc__
transform.__doc__ = Operator.transform.__doc__

@property
def dependencies(self):
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion merlin/dag/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def __radd__(self, other):
return self + other

def __rshift__(self, operator):
if isinstance(operator, type) and issubclass(operator, merlin.dag.BaseOperator):
if isinstance(operator, type) and issubclass(operator, merlin.dag.Operator):
# handle case where an operator class is passed
operator = operator()

Expand Down
12 changes: 6 additions & 6 deletions tests/unit/dag/ops/test_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
import pytest

from merlin.core.protocols import Transformable
from merlin.dag.base_operator import BaseOperator
from merlin.dag.executors import DaskExecutor, LocalExecutor
from merlin.dag.graph import Graph
from merlin.dag.operator import Operator
from merlin.dag.ops.stat_operator import StatOperator
from merlin.dag.ops.subgraph import Subgraph
from merlin.dag.selector import ColumnSelector
Expand All @@ -29,9 +29,9 @@

@pytest.mark.parametrize("engine", ["parquet"])
def test_subgraph(df):
ops = ["x"] >> BaseOperator() >> BaseOperator()
ops = ["x"] >> Operator() >> Operator()
subgraph_op = Subgraph("subgraph", ops)
main_graph_ops = ["x", "y"] >> BaseOperator() >> subgraph_op >> BaseOperator()
main_graph_ops = ["x", "y"] >> Operator() >> subgraph_op >> Operator()

main_graph = Graph(main_graph_ops)

Expand All @@ -57,7 +57,7 @@ def fit_finalize(self, dask_stats):

fit_test_op = FitTestOp()
subgraph_op = Subgraph("subgraph", ["x"] >> fit_test_op)
main_graph_ops = ["x", "y"] >> BaseOperator() >> subgraph_op >> BaseOperator()
main_graph_ops = ["x", "y"] >> Operator() >> subgraph_op >> Operator()

main_graph = Graph(main_graph_ops)
main_graph.construct_schema(dataset.schema)
Expand All @@ -72,7 +72,7 @@ def fit_finalize(self, dask_stats):

@pytest.mark.parametrize("engine", ["parquet"])
def test_subgraph_looping(dataset):
class LoopingTestOp(BaseOperator):
class LoopingTestOp(Operator):
def transform(
self, col_selector: ColumnSelector, transformable: Transformable
) -> Transformable:
Expand All @@ -84,7 +84,7 @@ def transform(
subgraph,
loop_until=lambda transformable: (transformable["x"] > 5.0).all(),
)
main_graph_ops = ["x", "y"] >> BaseOperator() >> subgraph_op >> BaseOperator()
main_graph_ops = ["x", "y"] >> Operator() >> subgraph_op >> Operator()

main_graph = Graph(main_graph_ops)
main_graph.construct_schema(dataset.schema)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/dag/test_base_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
#
import pytest

from merlin.dag.base_operator import BaseOperator as Operator
from merlin.dag.graph import Graph
from merlin.dag.operator import Operator
from merlin.dag.selector import ColumnSelector
from merlin.schema import Schema

Expand Down
Loading

0 comments on commit 2d68307

Please sign in to comment.