Data loader merlin graph transforms and embeddings #37
merlin/loader/loader_base.py
Outdated
@@ -438,21 +467,21 @@ def _to_tensor(self, df): | |||
tensor in the appropriate library, with an optional | |||
dtype kwarg to do explicit casting if need be | |||
""" | |||
raise NotImplementedError | |||
return df.to_cupy() |
should be removed
merlin/loader/loader_base.py
Outdated
|
||
def _get_device_ctx(self, dev): | ||
""" | ||
One of the mandatory functions a child class needs | ||
to implement. Maps from a GPU index to a framework | ||
context object for placing tensors on specific GPUs | ||
""" | ||
raise NotImplementedError | ||
return cp.cuda.Device(dev) |
should be removed
merlin/loader/loader_base.py
Outdated
|
||
def _cast_to_numpy_dtype(self, dtype): | ||
""" | ||
Get the numpy dtype from the framework dtype. | ||
""" | ||
raise NotImplementedError | ||
return dtype |
should be removed
# are all operators going to need to know about lists as tuples? | ||
# seems like we could benefit from an object here that encapsulates | ||
# both lists and scalar tensor types? | ||
if self.transforms: |
We should think about creating a comprehensive "column" class that can be subclassed into ScalarColumn and ListColumn. This would hide the tuple format behind a DataFrame-series-style interface that is friendlier to the other parts of Merlin, i.e. the graph. The use case: what if I want to do some in-batch processing on a list column after the dataloader? Abstracting the tuple representation (values, nnz) would mean the user does not have to keep track of all that.
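To make the suggestion concrete, here is a minimal sketch of what such a class hierarchy could look like. Everything in it is hypothetical: the names `Column`, `ScalarColumn`, `ListColumn` and the `rows()` method are illustrations, not existing Merlin APIs, and it assumes that `nnz` holds the number of values in each row.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class Column:
    """Hypothetical base class; not an existing Merlin type."""

    name: str

    def rows(self) -> List:
        raise NotImplementedError


@dataclass
class ScalarColumn(Column):
    """One value per row."""

    values: Sequence

    def rows(self) -> List:
        return list(self.values)


@dataclass
class ListColumn(Column):
    """Variable-length rows stored as the (values, nnz) pair.

    Assumption: nnz[i] is the number of values belonging to row i.
    """

    values: Sequence  # flattened values for every row
    nnz: Sequence[int]  # per-row value counts

    def rows(self) -> List[List]:
        out, start = [], 0
        for count in self.nnz:
            out.append(list(self.values[start:start + count]))
            start += count
        return out


# The caller works with rows and never touches the tuple layout directly.
item_ids = ListColumn("item_ids", values=[1, 2, 3, 4, 5], nnz=[2, 0, 3])
item_rows = item_ids.rows()
```

An operator written against `Column.rows()` would not need to know whether a column is a scalar or a list, which is the question raised in the code comment under review.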
This PR requires a core change in https://github.com/NVIDIA-Merlin/core/pull/152/files
rerun tests
|
||
|
||
class TFEmbeddingOperator(BaseOperator): | ||
"""Create an operator that will apply a tf embedding table to supplied indices. |
Most of these are repeated with small tweaks. It would be nice to converge them so we don't have three operators doing the same thing with different inputs.
from merlin.schema import ColumnSchema, Schema, Tags | ||
|
||
|
||
class TorchEmbeddingOperator(BaseOperator): |
Same as in the TensorFlow case: many of these operators differ only slightly, but we have kept them separate to avoid confusion and to make each operator's uses and use cases clearer to users. It would be good to move to a state where we have just one operator for this (as previously stated).
|
||
|
||
@pytest.fixture(scope="session") | ||
def rev_embedding_ids(embedding_ids, tmpdir_factory): |
The reverse embeddings fixture is used to check that id_lookup works correctly. Here the indexes are reversed, [99999:1], whereas in embedding_ids above they are [1:99999]. This lets us enumerate the batches to pull out the expected embedding values and assert that they match what came back in each batch.
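The idea behind the reversed fixture can be sketched in plain Python. This is an illustration under stated assumptions, not the test from the PR: `lookup_embeddings` is a hypothetical helper, the table is shrunk to 10 ids, and each id `k` is given the embedding `[k, k]` so the expected value is easy to compute.

```python
def lookup_embeddings(batch_ids, table_ids, table):
    """Return the embedding row for each id, whatever order the table is in."""
    row_of = {id_: row for row, id_ in enumerate(table_ids)}
    return [table[row_of[i]] for i in batch_ids]


num_ids = 10
batch_size = 4

forward_ids = list(range(1, num_ids + 1))  # ids in ascending order
reverse_ids = forward_ids[::-1]  # same ids, table stored in descending order
reverse_table = [[float(i), float(i)] for i in reverse_ids]

# Enumerate the batches so the expected ids for each batch can be derived
# from the batch index alone.
for batch_idx, start in enumerate(range(0, num_ids, batch_size)):
    batch = forward_ids[start:start + batch_size]
    got = lookup_embeddings(batch, reverse_ids, reverse_table)
    expected = [[float(i), float(i)] for i in batch]
    assert got == expected, f"batch {batch_idx} returned the wrong rows"
```

If the lookup used row position instead of id, the first batch would come back with the embeddings for the highest ids and the assertion would fail, which is what the reversed ordering is meant to catch.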
moved from private to public so needs new fork
This PR adds the ability to run Merlin graph transforms over the batches of data that come out of the data loader. The operators introduced here are the embedding operators, which allow embedding representations to be added to records at the batch level.
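As a rough picture of what batch-level transforms mean here, the sketch below applies a chain of transforms to each batch as it is yielded. It does not use the Merlin API: `embedding_transform`, `transformed_batches`, the dict-of-lists batch format and the column names are all assumptions made for illustration.

```python
def embedding_transform(table, id_column, out_column):
    """Build a transform that adds an embedding column to a batch.

    Hypothetical helper: `table` maps an id to its embedding vector.
    """

    def apply(batch):
        out = dict(batch)  # leave the incoming batch untouched
        out[out_column] = [table[i] for i in batch[id_column]]
        return out

    return apply


def transformed_batches(batches, transforms):
    """Yield each batch after running every transform over it, in order."""
    for batch in batches:
        for transform in transforms:
            batch = transform(batch)
        yield batch


table = {1: [0.1, 0.2], 2: [0.3, 0.4], 3: [0.5, 0.6]}
loader_output = [{"id": [1, 2]}, {"id": [3]}]

add_embeddings = embedding_transform(table, "id", "embedding")
results = list(transformed_batches(loader_output, [add_embeddings]))
```

Each batch keeps its original columns and gains an `embedding` column, so the embeddings are attached per batch rather than materialized for the whole dataset up front.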