-
Notifications
You must be signed in to change notification settings - Fork 26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Data loader merlin graph transforms and embeddings #37
Changes from 4 commits
db89cab
736c8e8
c1572fa
0eb8d81
6334d91
f978c57
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ | |
import threading | ||
import warnings | ||
from collections import OrderedDict | ||
from typing import List | ||
|
||
import numpy as np | ||
|
||
|
@@ -36,6 +37,8 @@ | |
make_df, | ||
pull_apart_list, | ||
) | ||
from merlin.dag import BaseOperator, ColumnSelector, DictArray, Graph, Node | ||
from merlin.dag.executors import LocalExecutor | ||
from merlin.io import shuffle_df | ||
from merlin.schema import Tags | ||
|
||
|
@@ -59,6 +62,7 @@ def __init__( | |
global_size=None, | ||
global_rank=None, | ||
drop_last=False, | ||
transforms=None, | ||
): | ||
self.dataset = dataset | ||
self.batch_size = batch_size | ||
|
@@ -79,6 +83,7 @@ def __init__( | |
) | ||
dataset.schema = dataset.infer_schema() | ||
|
||
self.schema = dataset.schema | ||
self.sparse_names = [] | ||
self.sparse_max = {} | ||
self.sparse_as_dense = set() | ||
|
@@ -126,6 +131,30 @@ def __init__( | |
self._batch_itr = None | ||
self._workers = None | ||
|
||
if transforms is not None: | ||
|
||
if isinstance(transforms, List): | ||
carry_node = Node(ColumnSelector("*")) | ||
for transform in transforms: | ||
# check that each transform is an operator: | ||
if not isinstance(transform, BaseOperator): | ||
raise TypeError(f"Detected invalid transform, {type(transform)}") | ||
carry_node = carry_node >> transform | ||
transform_graph = Graph(carry_node) | ||
elif isinstance(transforms, Graph): | ||
transform_graph = transforms | ||
self.transforms = transform_graph.construct_schema(self.schema).output_node | ||
self.schema = self.transforms.output_schema | ||
# should we make one main local executor and hold that on dataloader? | ||
# Or build dynamically per batch? | ||
# is there a reason we might expose this to the user? | ||
# change to something other than local? | ||
self.executor = LocalExecutor() | ||
else: | ||
# Like this to be more explicit about what occurs. | ||
self.transforms = None | ||
self.executor = None | ||
|
||
@property | ||
def _buff(self): | ||
if self.__buff is None: | ||
|
@@ -438,21 +467,21 @@ def _to_tensor(self, df): | |
tensor in the appropriate library, with an optional | ||
dtype kwarg to do explicit casting if need be | ||
""" | ||
raise NotImplementedError | ||
return df.to_cupy() | ||
|
||
def _get_device_ctx(self, dev): | ||
""" | ||
One of the mandatory functions a child class needs | ||
to implement. Maps from a GPU index to a framework | ||
context object for placing tensors on specific GPUs | ||
""" | ||
raise NotImplementedError | ||
return cp.cuda.Device(dev) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should be removed |
||
|
||
def _cast_to_numpy_dtype(self, dtype): | ||
""" | ||
Get the numpy dtype from the framework dtype. | ||
""" | ||
raise NotImplementedError | ||
return dtype | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should be removed |
||
|
||
def _split_fn(self, tensor, idx, axis=0): | ||
raise NotImplementedError | ||
|
@@ -551,8 +580,35 @@ def _handle_tensors(self, tensors): | |
labels = None | ||
if len(self.label_names) > 0: | ||
labels = X.pop(self.label_names[0]) | ||
|
||
# with tensors all in one dictionary | ||
# apply transforms graph here against the tensors | ||
# | ||
# tensors = local_executor.transform_data(tensors) | ||
|
||
# bad thing here is that we don't have the labels, what is required, for some | ||
# reason by op transform logic? | ||
# bad thing here is that some of these entries are lists, which are tuples? | ||
# are all operators going to need to know about lists as tuples? | ||
# seems like we could benefit from an object here that encapsulates | ||
# both lists and scalar tensor types? | ||
if self.transforms: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should think about creating a comprehensive "column" class that can be sub-classed to ScalarColumn and ListColumn. This will hide the tuple format behind a df series type interface that will be more friendly to the other parts of merlin, i.e. the graph. The use case is what if I want to do some after dataloader inbatch processing to a list column. It will be easier to abstract that tuple representation (values, nnz) and allow the user to not have to worry about keeping track of all that. |
||
X = self.executor.transform(DictArray(X), [self.transforms]) | ||
|
||
return X, labels | ||
|
||
def _pack(self, gdf): | ||
if isinstance(gdf, np.ndarray): | ||
return gdf | ||
elif hasattr(gdf, "to_dlpack") and callable(getattr(gdf, "to_dlpack")): | ||
return gdf.to_dlpack() | ||
elif hasattr(gdf, "to_numpy") and callable(getattr(gdf, "to_numpy")): | ||
gdf = gdf.to_numpy() | ||
if isinstance(gdf[0], list): | ||
gdf = np.stack(gdf) | ||
return gdf | ||
return gdf.toDlpack() | ||
|
||
|
||
class ChunkQueue: | ||
"""This class takes partitions (parts) from an merlin.io.Dataset | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# | ||
# Copyright (c) 2021, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# | ||
# Copyright (c) 2021, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
# flake8: noqa | ||
from merlin.loader.ops.embeddings.tf_embedding_op import ( | ||
Numpy_Mmap_TFEmbedding, | ||
Numpy_TFEmbeddingOperator, | ||
TFEmbeddingOperator, | ||
) | ||
from merlin.loader.ops.embeddings.torch_embedding_op import ( | ||
Numpy_Mmap_TorchEmbedding, | ||
Numpy_TorchEmbeddingOperator, | ||
TorchEmbeddingOperator, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should be removed