Type checking for pytrees #3340
Thanks a lot for adding this 😄
There are two goals for type hints: to catch errors, and to improve readability by documenting intended usage. I think that readability is the most important and the easiest to achieve, e.g., just use … The title of this issue is about type checking, i.e., preventing type errors. I feel that this is much harder, especially in the parts of the code that were written to make heavy use of dynamic typing. The type checkers try hard, but this is a very murky area. We have recently filed a bug with pytype in the presence of … My thinking is that for PyTree we get most of the benefit from defining a type alias …
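A minimal sketch of the "documentation-only" alias idea, assuming the alias is simply `Any` (the helper `tree_scale` is hypothetical). A static checker treats `PyTree` exactly like `Any`, so it catches no errors, but the signature tells the reader what kind of argument is expected.

```python
from typing import Any

import jax

# Documentation-only alias: type checkers treat PyTree exactly like Any,
# so nothing is checked, but signatures now say "a pytree goes here".
PyTree = Any


def tree_scale(tree: PyTree, factor: float) -> PyTree:
    """Multiply every leaf of `tree` by `factor` (hypothetical helper)."""
    return jax.tree_util.tree_map(lambda leaf: leaf * factor, tree)
```

For example, `tree_scale({"a": 1.0, "b": (2.0, 3.0)}, 2.0)` keeps the dict and tuple structure and doubles each leaf.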
Yes, you're right about the benefit of documentation. That's why I'm already using … Could you link me to your bug with pytype? I'm curious.
I agree for now, but pretty soon …
I can imagine two ways in which it might make sense to type-check pytrees: …
In theory, the first type could be checked with …
Has there been any progress on this bug? What's the current best-known practice? Is it still to define PyTree = Any?
Yes, in the sense that I think we're fairly confident it cannot be done without a Python type checker that supports recursive types, which mypy and pytype (our current type checkers) do not. So, you should write …
from typing import Dict, List, Union
JSON = Union[Dict[str, 'JSON'], List['JSON'], str, int, float, bool, None]
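For comparison with the JSON alias, here is a hypothetical recursive PyTree alias written the same way. It is only meaningful under a type checker that supports recursive type aliases (the limitation described above). It also deliberately covers only dicts, lists, tuples, None, arrays and Python scalars, so custom node types registered with `jax.tree_util` would be rejected, which is one reason `Any` remained the pragmatic choice. `Leaf`, `PyTree` and `count_leaves` are illustrative names, not JAX API.

```python
from typing import Dict, List, Tuple, Union

import jax

# Hypothetical recursive alias mirroring the JSON example: a leaf is an
# array or a Python scalar, and containers nest further PyTrees.
Leaf = Union[jax.Array, float, int, bool]
PyTree = Union[Leaf, None, Tuple["PyTree", ...], List["PyTree"], Dict[str, "PyTree"]]


def count_leaves(tree: PyTree) -> int:
    # JAX treats None as an empty subtree, so it contributes no leaves.
    return len(jax.tree_util.tree_leaves(tree))
```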
Given the above, are there any plans to add a standard …
How would you do better than …
FYI, if …
That sounds like it could be a good solution someday.
At this point, I would suggest defining project-specific PyTree types, e.g.,
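A minimal sketch of what a project-specific type might look like, assuming a project whose parameters are always a two-level dict of arrays (layer name, then parameter name). `LayerParams`, `Params` and `sum_of_squares` are hypothetical names; the aliases describe this one project's trees, not pytrees in general, and so avoid the need for recursive types.

```python
from typing import Dict

import jax
import jax.numpy as jnp

# Hypothetical project-specific aliases: layer name -> parameter name -> array.
LayerParams = Dict[str, jax.Array]
Params = Dict[str, LayerParams]


def sum_of_squares(params: Params) -> jax.Array:
    """Sum of squared entries over every array in a Params tree."""
    leaves = jax.tree_util.tree_leaves(params)
    # Start from a 0-d array so the result is a jax.Array even for empty trees.
    return sum((jnp.sum(leaf ** 2) for leaf in leaves), jnp.zeros(()))
```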
Project-specific types could (in principle) be handled with generics, e.g.,
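One way the generics idea could be sketched, assuming a type checker with support for generic recursive aliases: `Tree[T]` is a hypothetical alias for a tree whose leaves all have type `T`, and `typed_tree_map` is a hypothetical thin wrapper that gives `jax.tree_util.tree_map` a leaf-typed signature. At runtime the annotations do nothing; the wrapper just calls `tree_map`.

```python
from typing import Callable, Dict, List, Tuple, TypeVar, Union

import jax

T = TypeVar("T")
U = TypeVar("U")

# Hypothetical generic alias: a tree whose leaves all have type T.
# Because it refers to itself, it needs recursive-alias support in the checker.
Tree = Union[T, Tuple["Tree[T]", ...], List["Tree[T]"], Dict[str, "Tree[T]"]]


def typed_tree_map(f: Callable[[T], U], tree: "Tree[T]") -> "Tree[U]":
    """Leaf-typed wrapper around jax.tree_util.tree_map (single tree only)."""
    return jax.tree_util.tree_map(f, tree)
```

With such a signature, a checker could in principle infer that mapping `str` over a `Tree[int]` yields a `Tree[str]`, which is the kind of leaf-level information `PyTree = Any` cannot carry.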
Potentially there's some room for libraries like Flax to define struct types that are compatible with this sort of type checking, but otherwise I don't think there's much to be done in JAX.
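I think the only really useful definition of Pytree is Pytree = Any, because almost any value can be a leaf:
import jax
mul2 = lambda x: x * 2
jax.tree_util.tree_map(mul2, 3)  # 6
jax.tree_util.tree_map(mul2, "hi ")  # "hi hi "
jax.tree_util.tree_map(mul2, TypeThatImplementsMul())  # ??? (placeholder: any class defining __mul__)
# and so on...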
Note that … But for the reasons above, it appears that Python's static type-checking spec doesn't have much to say about PyTrees as currently implemented.
Right, but individual projects can probably guarantee that they are only going to use a restricted set of types in PyTree leaves. For example, every leaf that is a neural net parameter needs to be a (float) array. |
In Flax we use something like this, but it's not informative of the leaf types at all:
from typing import Any, Mapping
from flax.core import FrozenDict
Collection = Mapping[str, Any]
FrozenVariableDict = FrozenDict[str, Collection]
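Since static annotations like these say nothing about leaf types, a project that guarantees "every parameter leaf is a float array" could enforce that promise at runtime instead. A minimal sketch, assuming leaves should be `jax.Array` values with a floating-point dtype; `all_leaves_are_float_arrays` is a hypothetical helper, not part of JAX or Flax, and it complements rather than replaces static typing.

```python
import jax
import jax.numpy as jnp


def all_leaves_are_float_arrays(tree) -> bool:
    """Return True if every leaf of `tree` is a floating-point jax.Array.

    An empty tree trivially passes.
    """
    return all(
        isinstance(leaf, jax.Array) and jnp.issubdtype(leaf.dtype, jnp.floating)
        for leaf in jax.tree_util.tree_leaves(tree)
    )
```

Integer arrays and plain Python floats both fail the check, so a call like `assert all_leaves_are_float_arrays(params)` at a model's entry point catches leaves that a `Mapping[str, Any]` annotation lets through.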
Glad to hear JAX will eventually have annotations! I was following this issue for updates. I actually started trying to add annotations myself before I realized how much work it was going to be.
Something it would be nice to expose as soon as possible, though, is a set of base types for tensors and pytrees:
Originally posted by @NeilGirdhar in #1555 (comment)
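On the tensor side, recent JAX releases do expose base types: `jax.Array` for JAX arrays and `jax.typing.ArrayLike` for inputs that JAX will convert. A minimal sketch of the common pattern of accepting `ArrayLike` and returning `jax.Array` follows; `unit_vector` is a hypothetical helper, and this covers tensors only, since the pytree side is still the `Any` alias discussed above.

```python
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


def unit_vector(x: ArrayLike) -> jax.Array:
    """Scale `x` to unit L2 norm (hypothetical helper)."""
    x = jnp.asarray(x)  # convert any ArrayLike input to a jax.Array
    return x / jnp.linalg.norm(x)
```

Note that `ArrayLike` intentionally excludes Python lists and tuples, so a checker would flag `unit_vector([3.0, 4.0])` even though it runs; passing `jnp.array([3.0, 4.0])` satisfies both.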