Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JIT] Support for NamedTuple #21428

Closed
wants to merge 1 commit into from

Conversation

jamesr66a
Copy link
Collaborator

@jamesr66a jamesr66a commented Jun 5, 2019

Resolves https://github.com/pytorch/lockdown/issues/18

This implements NamedTuple by taking advantage of the existing names field in TupleType.

TODO: This currently doesn't retain the NamedTuple-ness through serialization. Discussed with @suo offline, we can probably make a way to define an anonymous NamedTuple in script (e.g. NamedTuple('Foo', [('a', int), ('b', float), ('c', List[float])]) and serialize that
TODO: implement support for calling the constructor with kwargs

@pytorchbot pytorchbot added oncall: jit Add this issue/PR to JIT oncall triage queue module: internals Related to internal abstractions in c10 and ATen module: pybind Related to our Python bindings / interactions with other Python libraries labels Jun 5, 2019

@torch.jit.script
def foo(x) -> float:
fv = FeatureVector(3.0, [3.0], 3.0) # noqa
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

flake8 was freaking out here in my pre-commit hook for some reason

@jamesr66a jamesr66a requested a review from suo June 6, 2019 16:55
@jamesr66a jamesr66a changed the title [WIP][JIT] Support for NamedTuple [JIT] Support for NamedTuple Jun 6, 2019
@jamesr66a jamesr66a force-pushed the named_tuple_support2 branch 2 times, most recently from 826e4f2 to 75ac46a Compare June 10, 2019 17:25
Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jamesr66a has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jamesr66a has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jamesr66a has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Member

@suo suo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This duplication is sad, but I can't really see a better way short of merging the tuple and class concepts together. Can you add a lockdown P1 to explore whether that's possible? I think it should work if we implement a scalar replacement of aggregates pass.

Otherwise, left some comments inline

torch/csrc/jit/script/init.cpp Outdated Show resolved Hide resolved
aten/src/ATen/core/jit_type.h Outdated Show resolved Hide resolved
aten/src/ATen/core/jit_type.h Outdated Show resolved Hide resolved
aten/src/ATen/core/jit_type.h Outdated Show resolved Hide resolved
torch/jit/__init__.py Outdated Show resolved Hide resolved
torch/csrc/jit/script/init.cpp Outdated Show resolved Hide resolved
torch/csrc/jit/script/init.cpp Outdated Show resolved Hide resolved
@@ -192,9 +193,21 @@ struct GenericDict;

struct CAFFE2_API Tuple : public List<IValue> {
using List<IValue>::List;

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw, in your serialization PR you'll need to fix up the pickler as well—currently the type member will just get ignore and it'll get pickled as a tuple, which isn't what we want.

@jamesr66a jamesr66a requested a review from suo June 10, 2019 21:39
torch/jit/__init__.py Outdated Show resolved Hide resolved
Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jamesr66a has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Member

@suo suo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! Much simpler than before. Only one substantive comment

torch/csrc/jit/script/init.cpp Outdated Show resolved Hide resolved
torch/jit/__init__.py Outdated Show resolved Hide resolved
Copy link
Member

@suo suo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm. I only glanced the tupleptr reverts, I assume they are just straight reverts. In the future, please stack your PRs so they are easier to review 💯

Copy link
Contributor

@zdevito zdevito left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The structure looks good, but I don't there is any type checking being done on tuple construction.
I'd also like to see what happens on unification e.g. a named and unnamed tuple on either side of an if-statement.

torch/csrc/jit/script/sugared_value.cpp Outdated Show resolved Hide resolved
torch/csrc/jit/script/sugared_value.h Outdated Show resolved Hide resolved
torch/csrc/jit/script/sugared_value.cpp Show resolved Hide resolved
test/test_jit_py3.py Show resolved Hide resolved
Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jamesr66a is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Contributor

@zdevito zdevito left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Schema checking code is buggy. It is an easy fix but this isn't correct yet.

torch/csrc/jit/script/sugared_value.cpp Outdated Show resolved Hide resolved
Copy link
Contributor

@zdevito zdevito left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Schema checking code is buggy. It is an easy fix but this isn't correct yet.

torch/csrc/jit/script/sugared_value.cpp Outdated Show resolved Hide resolved
Summary:
Resolves pytorch/lockdown#18

This implements NamedTuple by taking advantage of the existing `names` field in `TupleType`.

TODO: This currently doesn't retain the NamedTuple-ness through serialization. Discussed with suo offline, we can probably make a way to define an anonymous NamedTuple in script (e.g. `NamedTuple('Foo', [('a', int), ('b', float), ('c', List[float])])` and serialize that
TODO: implement support for calling the constructor with kwargs
Pull Request resolved: pytorch#21428

Differential Revision: D15741564

fbshipit-source-id: b981d3c0f058afec5d6a7500d989bb10ac898278
Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jamesr66a is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Jun 15, 2019
Summary:
Resolves pytorch/lockdown#18

This implements NamedTuple by taking advantage of the existing `names` field in `TupleType`.

TODO: This currently doesn't retain the NamedTuple-ness through serialization. Discussed with suo offline, we can probably make a way to define an anonymous NamedTuple in script (e.g. `NamedTuple('Foo', [('a', int), ('b', float), ('c', List[float])])` and serialize that
TODO: implement support for calling the constructor with kwargs
Pull Request resolved: pytorch/pytorch#21428

Differential Revision: D15741564

Pulled By: jamesr66a

fbshipit-source-id: c077cbcea1880675ca6deb340a9ec78f824a136c
jamesr66a pushed a commit that referenced this pull request Jun 15, 2019
Summary:
Resolves https://github.com/pytorch/lockdown/issues/18

This implements NamedTuple by taking advantage of the existing `names` field in `TupleType`.

TODO: This currently doesn't retain the NamedTuple-ness through serialization. Discussed with suo offline, we can probably make a way to define an anonymous NamedTuple in script (e.g. `NamedTuple('Foo', [('a', int), ('b', float), ('c', List[float])])` and serialize that
TODO: implement support for calling the constructor with kwargs
Pull Request resolved: #21428

Differential Revision: D15741564

fbshipit-source-id: b981d3c0f058afec5d6a7500d989bb10ac898278
@facebook-github-bot
Copy link
Contributor

@jamesr66a merged this pull request in 4bcc72f.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
caffe2 Merged module: cpp Related to C++ API module: internals Related to internal abstractions in c10 and ATen module: pybind Related to our Python bindings / interactions with other Python libraries oncall: jit Add this issue/PR to JIT oncall triage queue
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants