Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Data loader merlin graph transforms and embeddings #37

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions merlin/loader/loader_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import threading
import warnings
from collections import OrderedDict
from typing import List

import numpy as np

Expand All @@ -36,6 +37,8 @@
make_df,
pull_apart_list,
)
from merlin.dag import BaseOperator, ColumnSelector, DictArray, Graph, Node
from merlin.dag.executors import LocalExecutor
from merlin.io import shuffle_df
from merlin.schema import Tags

Expand All @@ -59,6 +62,7 @@ def __init__(
global_size=None,
global_rank=None,
drop_last=False,
transforms=None,
):
self.dataset = dataset
self.batch_size = batch_size
Expand All @@ -79,6 +83,7 @@ def __init__(
)
dataset.schema = dataset.infer_schema()

self.schema = dataset.schema
self.sparse_names = []
self.sparse_max = {}
self.sparse_as_dense = set()
Expand Down Expand Up @@ -126,6 +131,30 @@ def __init__(
self._batch_itr = None
self._workers = None

if transforms is not None:

if isinstance(transforms, List):
carry_node = Node(ColumnSelector("*"))
for transform in transforms:
# check that each transform is an operator:
if not isinstance(transform, BaseOperator):
raise TypeError(f"Detected invalid transform, {type(transform)}")
carry_node = carry_node >> transform
transform_graph = Graph(carry_node)
elif type(transforms, Graph):
transform_graph = transforms
self.transforms = transform_graph.construct_schema(self.schema).output_node
self.schema = self.transforms.output_schema
# should we make one main local executor and hold that on dataloader?
# Or build dynamically per batch?
# is there a reason we might expose this to the user?
# change to something other than local?
self.executor = LocalExecutor()
else:
# Like this to be more explicit about what occurs.
self.transforms = None
self.executor = None

@property
def _buff(self):
if self.__buff is None:
Expand Down Expand Up @@ -551,8 +580,35 @@ def _handle_tensors(self, tensors):
labels = None
if len(self.label_names) > 0:
labels = X.pop(self.label_names[0])

# with tensors all in one dictionary
# apply transforms graph here against the tensors
#
# tensors = local_executor.transform_data(tensors)

# bad thing here is that we dont have the labels, what is required, for some
# reason by op transform logic?
# bad thing here is that some of this entries are lists, which are tuples?
# are all operators going to need to know about lists as tuples?
# seems like we could benefit from an object here that encapsulates
# both lists and scalar tensor types?
if self.transforms:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should think about creating a comprehensive "column" class that can be sub-classed to ScalarColumn and ListColumn. This will hide the tuple format behind a df series type interface that will be more friendly to the other parts of merlin, i.e. the graph. The use case is what if I want to do some after dataloader inbatch processing to a list column. It will be easier to abstract that tuple representation (values, nnz) and allow the user to not have to worry about keeping track of all that.

X = self.executor.transform(DictArray(X), [self.transforms])

return X, labels

def _pack(self, gdf):
if isinstance(gdf, np.ndarray):
return gdf
elif hasattr(gdf, "to_dlpack") and callable(getattr(gdf, "to_dlpack")):
return gdf.to_dlpack()
elif hasattr(gdf, "to_numpy") and callable(getattr(gdf, "to_numpy")):
gdf = gdf.to_numpy()
if isinstance(gdf[0], list):
gdf = np.stack(gdf)
return gdf
return gdf.toDlpack()


class ChunkQueue:
"""This class takes partitions (parts) from an merlin.io.Dataset
Expand Down
15 changes: 15 additions & 0 deletions merlin/loader/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
27 changes: 27 additions & 0 deletions merlin/loader/ops/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# flake8: noqa
from merlin.loader.ops.embeddings.tf_embedding_op import (
Numpy_Mmap_TFEmbedding,
Numpy_TFEmbeddingOperator,
TFEmbeddingOperator,
)
from merlin.loader.ops.embeddings.torch_embedding_op import (
Numpy_Mmap_TorchEmbedding,
Numpy_TorchEmbeddingOperator,
TorchEmbeddingOperator,
)
248 changes: 248 additions & 0 deletions merlin/loader/ops/embeddings/tf_embedding_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
import numpy as np
import tensorflow as tf

from merlin.core.protocols import Transformable
from merlin.dag import BaseOperator
from merlin.dag.selector import ColumnSelector
from merlin.schema import ColumnSchema, Schema, Tags


class TFEmbeddingOperator(BaseOperator):
"""Create an operator that will apply a tf embedding table to supplied indices.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most of these are repeated with small tweaks, would be nice to be able to converge so we dont have three operators for the same thing just using different inputs.

This operator allows the user to supply an id lookup table if the indices supplied
via the id_lookup_table. Embedding table is stored in host memory.

Parameters
----------
embeddings : np.ndarray
numpy ndarray representing embedding values
lookup_key : str, optional
the name of the column that will be used as indices, by default "id"
embedding_name : str, optional
name of new column of embeddings, added to output, by default "embeddings"
id_lookup_table : np.array, optional
numpy array of values that represent embedding indices, by default None
"""

def __init__(
self,
embeddings: np.ndarray,
lookup_key: str = "id",
embedding_name: str = "embeddings",
id_lookup_table=None,
):
self.embeddings = (
embeddings if isinstance(embeddings, tf.Tensor) else tf.convert_to_tensor(embeddings)
)
self.lookup_key = lookup_key
self.embedding_name = embedding_name
self.id_lookup_table = id_lookup_table

def transform(
self, col_selector: ColumnSelector, transformable: Transformable
) -> Transformable:
indices = transformable[self.lookup_key]
if self.id_lookup_table:
indices = np.in1d(self.id_lookup_table, indices)
embeddings = tf.nn.embedding_lookup(self.embeddings, indices)
transformable[self.embedding_name] = embeddings
return transformable

def compute_output_schema(
self,
input_schema: Schema,
col_selector: ColumnSelector,
prev_output_schema: Schema = None,
) -> Schema:
"""Creates the output schema for this operator.

Parameters
----------
input_schema : Schema
schema coming from ancestor nodes
col_selector : ColumnSelector
subselection of columns to apply to this operator
prev_output_schema : Schema, optional
the output schema of the previously executed operators, by default None

Returns
-------
Schema
Schema representing the correct output for this operator.
"""
col_schemas = []
for _, col_schema in input_schema.column_schemas.items():
col_schemas.append(col_schema)
col_schemas.append(
ColumnSchema(
name=self.embedding_name,
tags=[Tags.CONTINUOUS],
dtype=self.embeddings.dtype.as_numpy_dtype,
is_list=True,
is_ragged=False,
)
)

return Schema(col_schemas)


class Numpy_TFEmbeddingOperator(BaseOperator):
"""Create an embedding table from supplied embeddings to add embedding entry
to records based on supplied indices. Support for indices lookup table is available.
Embedding table is stored in host memory.

Parameters
----------
embeddings : np.ndarray
numpy ndarray representing embedding values
lookup_key : str, optional
the name of the column that will be used as indices, by default "id"
embedding_name : str, optional
name of new column of embeddings, added to output, by default "embeddings"
id_lookup_table : np.array, optional
numpy array of values that represent embedding indices, by default None
"""

def __init__(
self,
embeddings: np.ndarray,
lookup_key: str = "id",
embedding_name: str = "embeddings",
id_lookup_table=None,
):
self.embeddings = embeddings
self.lookup_key = lookup_key
self.embedding_name = embedding_name
self.id_lookup_table = id_lookup_table

def transform(
self, col_selector: ColumnSelector, transformable: Transformable
) -> Transformable:
indices = transformable[self.lookup_key]
if self.id_lookup_table:
indices = np.in1d(self.id_lookup_table, indices)
embeddings = self.embeddings[indices]
transformable[self.embedding_name] = tf.convert_to_tensor(embeddings)
return transformable

def compute_output_schema(
self,
input_schema: Schema,
col_selector: ColumnSelector,
prev_output_schema: Schema = None,
) -> Schema:
"""Creates the output schema for this operator.

Parameters
----------
input_schema : Schema
schema coming from ancestor nodes
col_selector : ColumnSelector
subselection of columns to apply to this operator
prev_output_schema : Schema, optional
the output schema of the previously executed operators, by default None

Returns
-------
Schema
Schema representing the correct output for this operator.
"""
col_schemas = []
for _, col_schema in input_schema.column_schemas.items():
col_schemas.append(col_schema)
col_schemas.append(
ColumnSchema(
name=self.embedding_name,
tags=[Tags.CONTINUOUS],
dtype=self.embeddings.dtype,
is_list=True,
is_ragged=False,
)
)

return Schema(col_schemas)


class Numpy_Mmap_TFEmbedding(BaseOperator):
"""Operator loads numpy embedding table from file using memory map to be used to create
tensorflow embedding representations. This allows for larger than host memory embedding
tables to be used for embedding lookups. The only limit to the size is what fits in
storage, preferred storage device is SSD for faster lookups.

Parameters
----------
embedding_npz : numpy ndarray file
file holding numpy ndarray representing embedding table
ids_lookup_npz : numpy array file, optional
file holding numpy array of values that represent embedding indices, by default None
lookup_key : str, optional
the name of the column that will be used as indices, by default "id"
embedding_name : str, optional
name of new column of embeddings, added to output, by default "embeddings"
transform_function : _type_, optional
function that will transform embedding from numpy to torch, by default None
"""

def __init__(
self,
embedding_npz,
ids_lookup_npz=None,
lookup_key="id",
embedding_name="embeddings",
transform_function=None,
):
self.embeddings = np.load(embedding_npz, mmap_mode="r")
self.id_lookup = np.load(ids_lookup_npz) if ids_lookup_npz else None
self.lookup_key = lookup_key
self.embedding_name = embedding_name
self.transform_function = tf.convert_to_tensor

def transform(
self, col_selector: ColumnSelector, transformable: Transformable
) -> Transformable:
ids_tensor = transformable[self.lookup_key]
if self.id_lookup:
ids_tensor = np.in1d(self.id_lookup[:, 0], ids_tensor)
embeddings = self.embeddings[ids_tensor]
if self.transform_function:
transformable[self.embedding_name] = self.transform_function(embeddings)
else:
transformable[self.embedding_name] = embeddings
return transformable

def compute_output_schema(
self,
input_schema: Schema,
col_selector: ColumnSelector,
prev_output_schema: Schema = None,
) -> Schema:
"""Creates the output schema for this operator.

Parameters
----------
input_schema : Schema
schema coming from ancestor nodes
col_selector : ColumnSelector
subselection of columns to apply to this operator
prev_output_schema : Schema, optional
the output schema of the previously executed operators, by default None

Returns
-------
Schema
Schema representing the correct output for this operator.
"""
col_schemas = []
for _, col_schema in input_schema.column_schemas.items():
col_schemas.append(col_schema)
col_schemas.append(
ColumnSchema(
name=self.embedding_name,
tags=[Tags.CONTINUOUS],
dtype=self.embeddings.dtype,
is_list=True,
is_ragged=False,
)
)

return Schema(col_schemas)
Loading