-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix(treespec): make PyTreeSpec.is_prefix
to be consistent with PyTreeSpec.flatten_up_to
#94
Conversation
Codecov ReportAll modified lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #94 +/- ##
=========================================
Coverage 100.00% 100.00%
=========================================
Files 4 4
Lines 427 428 +1
=========================================
+ Hits 427 428 +1
Flags with carried forward coverage won't be shown. Click here to find out more.
☔ View full report in Codecov by Sentry. |
4c6bb14
to
a918781
Compare
5373b9b
to
8f9d1d3
Compare
ea093c6
to
6b3533d
Compare
6b3533d
to
7310a7a
Compare
@@ -88,7 +88,7 @@ bool PyTreeSpec::operator==(const PyTreeSpec& other) const { | |||
// NOLINTNEXTLINE[readability-qualified-auto] | |||
for (auto a = m_traversal.begin(); a != m_traversal.end(); ++a, ++b) { | |||
if (a->kind != b->kind || a->arity != b->arity || | |||
(a->node_data.ptr() == nullptr) != (b->node_data.ptr() == nullptr) || | |||
static_cast<bool>(a->node_data) != static_cast<bool>(b->node_data) || |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why change this? what's the original type of a->node_data, is it a pointer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a->node_data
is a py::object
. pybind11 suggests to use operator bool()
to check whether the py::object
instance is set. It is equivalent to this->ptr() != nullptr
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good job!
if (a->kind != b->kind || a->arity != b->arity || | ||
(a->node_data.ptr() == nullptr) != (b->node_data.ptr() == nullptr) || | ||
if (a->arity != b->arity || | ||
static_cast<bool>(a->node_data) != static_cast<bool>(b->node_data) || |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as above
@@ -288,8 +382,8 @@ py::list PyTreeSpec::Entries() const { | |||
return py::list{root.node_entries}; | |||
} | |||
switch (root.kind) { | |||
case PyTreeKind::None: | |||
case PyTreeKind::Leaf: { | |||
case PyTreeKind::Leaf: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why change this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I reorder the switch cases to match the order of the enum definitions.
|
||
case PyTreeKind::DefaultDict: | ||
case PyTreeKind::Deque: | ||
case PyTreeKind::Custom: | ||
case PyTreeKind::Custom: { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm
Description
Describe your changes in detail.
Motivation and Context
Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax
close #15213
if this solves the issue #15213Types of changes
What types of changes does your code introduce? Put an
x
in all the boxes that apply:Checklist
Go over all the following points, and put an
x
in all the boxes that apply.If you are unsure about any of these, don't hesitate to ask. We are here to help!
make format
. (required)make lint
. (required)make test
pass. (required)