Type checking for pytrees #3340

Open
shoyer opened this issue Jun 5, 2020 · 16 comments

Comments

@shoyer
Collaborator

shoyer commented Jun 5, 2020

Glad to hear JAX will eventually have annotations! I was following this issue for updates. I actually started trying to add annotations myself before I realized how much work it was going to be.

One thing it would be nice to expose as soon as possible though are base types for tensors and pytrees:

Tensor = Union[np.ndarray, jnp.ndarray]  # probably needs more things in the union like tracers and DeviceArray?
PyTree = Union[Tensor,
               'PyTreeLike',
               Tuple['PyTree', ...],
               List['PyTree'],
               Dict[Hashable, 'PyTree'],
               None]

Originally posted by @NeilGirdhar in #1555 (comment)

@NeilGirdhar
Contributor

Thanks a lot for adding this 😄

@gnecula
Collaborator

gnecula commented Jun 8, 2020

There are two goals for type hints: catching errors and improving readability by documenting intended usage. I think readability is the most important and the easiest to achieve, e.g., just use PyTree everywhere a pytree is expected.

The title of this issue is about type checking, i.e., preventing type errors. I feel that this is much harder, especially in the parts of the code that were written to make heavy use of dynamic typing. The type checkers try hard, but this is a very murky area. We recently filed a bug with pytype in the presence of Union.

My thinking is that for PyTree we get most of the benefit from defining a type alias PyTree = Any, along with a comment explaining what pytrees are. Of course, for other cases that are closer to static typing we should give proper type definitions.
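A minimal sketch of the alias-plus-comment approach described above. The comment wording and the `scale` helper are illustrative only, not taken from JAX; the alias documents intent but gives the checker nothing to verify.

```python
from typing import Any

import jax
import jax.numpy as jnp

# PyTree: a nested structure of containers (tuples, lists, dicts, None, and any
# type registered with jax.tree_util) whose leaves are arrays or other values.
# Declared as Any because arbitrary types can be registered as pytree nodes at
# runtime, so a static checker cannot enumerate them.
PyTree = Any


def scale(tree: PyTree, factor: float) -> PyTree:
    """Multiply every leaf of `tree` by `factor` (hypothetical helper)."""
    return jax.tree_util.tree_map(lambda leaf: leaf * factor, tree)


params = {"w": jnp.ones(3), "b": [jnp.zeros(2), 1.0]}
scaled = scale(params, 2.0)
```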

@NeilGirdhar
Contributor

NeilGirdhar commented Jun 11, 2020

Yes, you're right about the benefit of documentation. That's why I'm already using PyTree everywhere in my code.

Could you link me to your bug with pytype? I'm curious.

> My thinking is that for PyTree we get most of the benefit from defining a type alias PyTree = Any, along with a comment explaining what pytrees are. Of course, for other cases that are closer to static typing we should give proper type definitions.

I agree for now, but pretty soon np.ndarray and, as I understand it, jnp.ndarray are going to be type-checked, which means that my PyTree will "wake up". At that point, will we get static typing? We just need an easy way to get PyTreeLike to accept all of the user-defined classes. Someone mentioned Protocols, but I've never used them.
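One way the Protocol idea could look, sketched under assumptions: `PyTreeLike` below is a hypothetical protocol mirroring the `tree_flatten`/`tree_unflatten` methods that `jax.tree_util.register_pytree_node_class` expects; it is not part of JAX. `jax.Array` stands in for `jnp.ndarray`. A Protocol only describes a class's shape for the checker; the class must still be registered with JAX to act as a pytree node.

```python
from typing import Any, Dict, Hashable, Iterable, List, Protocol, Tuple, Union, runtime_checkable

import jax
import jax.numpy as jnp
import numpy as np


@runtime_checkable
class PyTreeLike(Protocol):
    """Hypothetical protocol for user-defined pytree nodes (not a JAX API)."""

    def tree_flatten(self) -> Tuple[Iterable[Any], Any]: ...

    @classmethod
    def tree_unflatten(cls, aux_data: Any, children: Iterable[Any]) -> "PyTreeLike": ...


Tensor = Union[np.ndarray, jax.Array]
PyTree = Union[Tensor, PyTreeLike, Tuple["PyTree", ...], List["PyTree"], Dict[Hashable, "PyTree"], None]


@jax.tree_util.register_pytree_node_class  # runtime registration is still required
class Point:
    def __init__(self, x: Any, y: Any) -> None:
        self.x = x
        self.y = y

    def tree_flatten(self) -> Tuple[Tuple[Any, Any], None]:
        return (self.x, self.y), None  # (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data: None, children: Iterable[Any]) -> "Point":
        return cls(*children)


def count_leaves(tree: PyTree) -> int:
    """Point matches PyTreeLike structurally, so a checker accepts it here."""
    return len(jax.tree_util.tree_leaves(tree))


p = Point(jnp.ones(2), jnp.zeros(2))
n_leaves = count_leaves(p)
```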

@shoyer
Collaborator Author

shoyer commented Jun 12, 2020

I can imagine two ways in which it might make sense to type check pytrees:

  • Type checking over the types of leaves, e.g., a pytree of JAX arrays
  • Type checking over the structure of trees, e.g., to verify that two pytrees are both dictionaries with the same keys

In theory, the first type could be checked with Generic and the second type could be checked (at least partially) with TypeVar.
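A rough sketch of both ideas, assuming a hypothetical `Pair` node class: `Generic` carries the leaf type (`Pair[jax.Array]`), and a `TypeVar` ties two arguments to the same static type. The checker only compares types, so it cannot see dict keys or list lengths; structural agreement is partial at best, and `tree_map` still checks it at runtime.

```python
import dataclasses
from typing import Generic, TypeVar

import jax
import jax.numpy as jnp

L = TypeVar("L")  # leaf type
S = TypeVar("S")  # type of the whole tree


@dataclasses.dataclass
class Pair(Generic[L]):
    """Hypothetical two-field node; Pair[jax.Array] says both leaves are arrays."""
    first: L
    second: L


jax.tree_util.register_pytree_node(
    Pair,
    lambda p: ((p.first, p.second), None),     # flatten -> (children, aux_data)
    lambda _aux, children: Pair(*children),    # unflatten
)


def add_trees(a: S, b: S) -> S:
    """Both arguments must share the static type S, e.g. Pair[jax.Array]."""
    return jax.tree_util.tree_map(lambda x, y: x + y, a, b)


a: Pair[jax.Array] = Pair(jnp.array(1.0), jnp.array(2.0))
b: Pair[jax.Array] = Pair(jnp.array(10.0), jnp.array(20.0))
total = add_trees(a, b)
```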

@billmark

Has there been any progress on this bug? What's the current best known practice? Is it still to define PyTree = Any?

@hawkinsp
Collaborator

Yes, in the sense that I think we're fairly confident it cannot be done without a Python type checker that supports recursive types, which mypy and pytype (our current type checkers) do not. So, you should write Any.

@XuehaiPan
Contributor

> Yes, in the sense that I think we're fairly confident it cannot be done without a Python type checker that supports recursive types, which mypy and pytype (our current type checkers) do not. So, you should write Any.

mypy has supported recursive types since v0.981, and they are enabled by default starting with v0.990. E.g.:

JSON = Union[Dict[str, 'JSON'], List['JSON'], str, int, float, bool, None]
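The same recursive-alias pattern can describe a JAX-flavored tree. This is a sketch with a hypothetical `ArrayTree` alias that covers only the built-in containers listed; custom registered node types are not included, which is the gap discussed further down the thread.

```python
from typing import Dict, List, Tuple, Union

import jax
import jax.numpy as jnp

# Recursive alias in the style of the JSON example above.
ArrayTree = Union[jax.Array, List["ArrayTree"], Tuple["ArrayTree", ...], Dict[str, "ArrayTree"], None]


def count_leaves(tree: ArrayTree) -> int:
    """Count array leaves by walking each case of the alias explicitly."""
    if tree is None:
        return 0  # JAX treats None as an empty subtree, not a leaf
    if isinstance(tree, dict):
        return sum(count_leaves(v) for v in tree.values())
    if isinstance(tree, (list, tuple)):
        return sum(count_leaves(v) for v in tree)
    return 1


tree: ArrayTree = {"w": jnp.ones(3), "layers": [jnp.zeros(2), (jnp.ones(1), None)]}
n = count_leaves(tree)
```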

@carlosgmartin
Contributor

Given the above, are there any plans to add a standard jax.PyTree type soon?

@jakevdp
Collaborator

jakevdp commented Apr 17, 2023

How would you do better than PyTree = Any, given the fact that arbitrary types can be registered as pytrees at runtime?
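The runtime-registration problem in a small sketch: the hypothetical class `Box` is a leaf until some code calls `jax.tree_util.register_pytree_node`, possibly in another module or library, and only then does JAX look inside it. A static alias written earlier has no way to know about that call.

```python
import jax
import jax.numpy as jnp


class Box:
    """Hypothetical user class; until registered, JAX treats it as a leaf."""

    def __init__(self, value):
        self.value = value


before = jax.tree_util.tree_leaves(Box(jnp.ones(2)))  # one leaf: the Box itself

jax.tree_util.register_pytree_node(
    Box,
    lambda b: ((b.value,), None),              # flatten -> (children, aux_data)
    lambda _aux, children: Box(children[0]),   # unflatten
)

after = jax.tree_util.tree_leaves(Box(jnp.ones(2)))  # one leaf: the inner array
```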

@NeilGirdhar
Contributor

NeilGirdhar commented Apr 17, 2023

> given the fact that arbitrary types can be registered as pytrees at runtime?

FYI, if ABCMeta.register is ever supported by type checkers (e.g., by mypy), then you could make PyTree a class that inherits from abc.ABC and register all of your pytree types, which would then be visible to type checkers.
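A sketch of that idea with hypothetical names (`PyTreeABC`, `register_everywhere`, `Node`). At runtime `isinstance` already works through `ABCMeta.register`; static checkers such as mypy do not currently follow `register` calls, so they would still not treat `Node` as a `PyTreeABC`. That missing support is the "if" in the comment.

```python
import abc

import jax


class PyTreeABC(abc.ABC):
    """Hypothetical marker ABC for everything JAX treats as a pytree node."""


for builtin in (tuple, list, dict):
    PyTreeABC.register(builtin)


def register_everywhere(cls, flatten, unflatten):
    """Hypothetical helper: register with JAX and with the marker ABC together."""
    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    PyTreeABC.register(cls)
    return cls


class Node:
    def __init__(self, children):
        self.children = list(children)


register_everywhere(Node, lambda n: (n.children, None), lambda _aux, ch: Node(ch))

node = Node([1.0, 2.0])
```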

@jakevdp
Collaborator

jakevdp commented Apr 17, 2023

That sounds like it could be a good solution someday.

@shoyer
Collaborator Author

shoyer commented Apr 17, 2023

At this point, I would suggest defining project-specific PyTree types, e.g.,

PyTree = dict[str, 'PyTree'] | list['PyTree'] | jax.Array

Project-specific types could (in principle) be handled with generics, e.g.,

from typing import Generic, TypeVar
import dataclasses

T = TypeVar('T')

@dataclasses.dataclass
class MyStruct(Generic[T]):
    x: T
    y: T

PyTree = dict[str, 'PyTree'] | list['PyTree'] | MyStruct['PyTree'] | jax.Array

Potentially there's some room for libraries like Flax to define struct types that are compatible with this sort of type checking, but otherwise I don't think there's much to be done in JAX.

@cgarciae
Collaborator

cgarciae commented Apr 19, 2023

I think the only really useful definition of Pytree is:

Pytree = Any

because tree_map accepts both pytrees and leaves (values of any non-registered type), so the following is valid:

import jax

mul2 = lambda x: x * 2

jax.tree_util.tree_map(mul2, 3)                        # 6
jax.tree_util.tree_map(mul2, "hi ")                    # "hi hi "
jax.tree_util.tree_map(mul2, TypeThatImplementsMul())  # ???
# and so on...

@jakevdp
Collaborator

jakevdp commented Apr 19, 2023

Note that jaxtyping has type annotations for pytrees that get around the above issues by only checking at runtime: that could be a good solution depending on your use-case.

But for the reasons above, it appears that Python's static type-checking spec doesn't have much to say about PyTrees as currently implemented.

@shoyer
Collaborator Author

shoyer commented Apr 19, 2023

> I think the only really useful definition of Pytree is:
>
> Pytree = Any
>
> because tree_map accepts both pytrees and leaves (values of any non-registered type), so the following is valid:

Right, but individual projects can probably guarantee that they are only going to use a restricted set of types in PyTree leaves. For example, every leaf that is a neural net parameter needs to be a (float) array.
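Such a project-level guarantee can at least be enforced at runtime. A minimal sketch with a hypothetical `check_float_leaves` helper; static checkers do not see this constraint, it only runs when called.

```python
import jax
import jax.numpy as jnp


def check_float_leaves(tree) -> None:
    """Raise TypeError unless every leaf is a floating-point jax.Array (hypothetical)."""
    for path, leaf in jax.tree_util.tree_flatten_with_path(tree)[0]:
        if not (isinstance(leaf, jax.Array) and jnp.issubdtype(leaf.dtype, jnp.floating)):
            raise TypeError(f"leaf at {jax.tree_util.keystr(path)} is not a float array: {leaf!r}")


params = {"dense": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}}
check_float_leaves(params)  # passes silently
```

An integer leaf such as `jnp.array(3)` would raise, with the offending key path in the message.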

@cgarciae
Collaborator

In Flax we use something like this, but it's not informative of the leaf types at all:

Collection = Mapping[str, Any]
FrozenVariableDict = FrozenDict[str, Collection]

9 participants