Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue with no props available to _get_props() for __eq__() and __hash__() on depickling #256

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

VincentStark
Copy link

@VincentStark VincentStark commented Oct 11, 2024

As described in #198 and changed in #237, the _get_props() method can fail on depickling because some of the array attributes might not exist.

A fix in #237 only partially resolved the issue, and only for __hash__() operation.

In this PR, I offer a different approach: instead of modifying the op function __eq__() (that was broken with the same has no attribute 'index_variadic' message in our case), we update _get_props() method not to fail if certain attributes are missing.

@patrick-kidger
Copy link
Owner

Thanks for the PR!
This looks reasonable to me, although I think we need to keep the previous hash implementation: I believe Python requires that a hash be stable across an object's lifetime.

Do you have a MWE for when this is necessary? I'd like to add this as a test, to be sure we don't break this in the future.

@VincentStark VincentStark force-pushed the fix/no-props-on-depickling-operations branch from 51eac9e to 2939a3c Compare October 16, 2024 03:30
@VincentStark VincentStark force-pushed the fix/no-props-on-depickling-operations branch from 2939a3c to cd9b939 Compare October 16, 2024 03:30
@VincentStark
Copy link
Author

VincentStark commented Oct 16, 2024

@patrick-kidger thank you for your quick response! I've updated the PR, and it took a little time to isolate MWE with the help of @BrandonAtomicAI

MWE

It fails on unpickle with has no attribute 'index_variadic' in the existing version, and doesn't fail in a version with my changes.
Please let me know if you need anything else.

test_model.py

import equinox as eqx
from jaxtyping import Array, Float


class Linear(eqx.Module):
    weight: eqx.nn.Linear

    def __init__(self, key):
        self.weight = eqx.nn.Linear(10, 10, key=key)

    def __call__(self, x: Float[Array, "n m"]):
        return self.weight(x)

test_pickle.py

import cloudpickle
import test_model

import jax.random as jr

cloudpickle.register_pickle_by_value(test_model)

net = test_model.Linear(key=jr.key(0))

pickled_model = cloudpickle.dumps(net)

with open("test.pkl", "wb") as f:
    f.write(pickled_model)

test_unpickle.py

import cloudpickle

with open("test.pkl", "rb") as f:
    net = cloudpickle.load(f)

@patrick-kidger
Copy link
Owner

Hmm, what versions are you using? With

jaxtyping==0.2.34
equinox==0.11.7
cloudpickle==3.1.0

and Python 3.11, I can't reproduce your issue.

I run first python test_pickle.py followed by python test_unpickle.py, and this does not raise an exception.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants