Implement pad #748

Merged (38 commits), Jul 19, 2024
Conversation

@jessegrabowski (Member) commented May 4, 2024

Description

Implement pt.pad, following the np.pad API with feature parity.

This is a very preliminary draft; I'm uploading it in this state so I can ask @ricardoV94 to look at the _linear_ramp_pad function and tell me if I'm missing something obvious related to shapes. It should follow numpy.lib.arraypad._get_linear_ramps. Also, the reflection pad uses a scan; I'm curious whether we can avoid that somehow, or whether it will be no big deal (probably the latter).

I'm also not sure where this should live. I put it in tensor/basic, but it might be better in tensor/extra_ops?
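For reference, this is the np.pad call surface being mirrored (a minimal sketch of the NumPy behavior; the exact pt.pad signature was still in flux at this point):

```python
import numpy as np

x = np.array([1, 2, 3])

# pad_width is (before, after) per axis; mode selects the fill strategy
np.pad(x, pad_width=(1, 2), mode="constant", constant_values=0)
# → array([0, 1, 2, 3, 0, 0])

# reflect mirrors the array around the edge values without repeating them
np.pad(x, pad_width=(2, 2), mode="reflect")
# → array([3, 2, 1, 2, 3, 2, 1])
```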

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@ricardoV94 (Member) commented May 4, 2024

What about a padding.py file?

Review thread on pytensor/tensor/basic.py (outdated, resolved)
@jessegrabowski (Member Author)

Sure, I'll make a new file. It's just not my default. I agree it doesn't belong in basic.

@jessegrabowski (Member Author)

Not quite 1:1 with the numpy features, but close. The more exotic padding schemes would take me more time to understand.

Still needs jax/numba overloads, but these should be very trivial.

codecov bot commented May 11, 2024

Codecov Report

Attention: Patch coverage is 94.30380% with 18 lines in your changes missing coverage. Please review.

Project coverage is 81.48%. Comparing base (c6d85d1) to head (bbeb300).

Additional details and impacted files


@@            Coverage Diff             @@
##             main     #748      +/-   ##
==========================================
+ Coverage   81.38%   81.48%   +0.09%     
==========================================
  Files         172      174       +2     
  Lines       46868    47166     +298     
  Branches    11423    11471      +48     
==========================================
+ Hits        38145    38434     +289     
- Misses       6542     6548       +6     
- Partials     2181     2184       +3     
Files Coverage Δ
pytensor/link/jax/dispatch/__init__.py 100.00% <100.00%> (ø)
pytensor/link/jax/dispatch/pad.py 100.00% <100.00%> (ø)
pytensor/tensor/subtensor.py 89.31% <100.00%> (+0.11%) ⬆️
pytensor/tensor/pad.py 97.14% <97.14%> (ø)
pytensor/tensor/extra_ops.py 87.64% <80.95%> (-0.98%) ⬇️

... and 4 files with indirect coverage changes

@jessegrabowski changed the title from "Implement pt.pad" to "Implement pad" on May 11, 2024
@ricardoV94 (Member) left a comment


Looks great so far, left some small suggestions.

Review threads (outdated, resolved): pytensor/tensor/extra_ops.py (5 threads), pytensor/tensor/pad.py (2 threads)
@jessegrabowski (Member Author)

Draft of the JAX overload. I need your input on the Pad OpFromGraph; I needed some way to hang on to the keyword arguments.

It seems like there might be a difference between how JAX and numpy handle mode="mean" padding, because the tests pass against numpy but not against JAX. I'll investigate more carefully, but it might be a JAX bug (I doubt this padding mode is ever used).
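For reference, this is the NumPy behavior the tests compare against: with the default stat_length, mode="mean" fills each side with the mean of all the original values along that axis.

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])
# Every padded entry is the axis mean of the original values (2.0 here)
np.pad(x, pad_width=(2, 2), mode="mean")
# → array([2., 2., 1., 2., 3., 2., 2.])
```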

I also think my loopy pads (symmetric, wrap) need to be redone, because they are failing a new test that arbitrarily pads every dimension of an nd input differently. So all that probably needs a re-design from the ground up.

@ricardoV94 (Member)

> I also think my loopy pads (symmetric, wrap) need to be redone, because they are failing a new test that arbitrarily pads every dimension of an nd input differently. So all that probably needs a re-design from the ground up.

You may need an operation per dimension
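One way to structure that, sketched here in plain NumPy (hypothetical helpers, not the PR's code): pad a single axis at a time, and apply one such operation per dimension.

```python
import numpy as np

def wrap_pad_axis(x, before, after, axis):
    # Wrap-pad one axis by concatenating slices taken from the opposite end.
    # Assumes before/after do not exceed the axis length (single wrap only).
    n = x.shape[axis]
    left = np.take(x, range(n - before, n), axis=axis)
    right = np.take(x, range(after), axis=axis)
    return np.concatenate([left, x, right], axis=axis)

def wrap_pad(x, pad_widths):
    # One padding operation per dimension, applied sequentially; later axes
    # wrap the already-padded result, which fills the corner regions.
    for axis, (before, after) in enumerate(pad_widths):
        x = wrap_pad_axis(x, before, after, axis)
    return x

x = np.arange(6).reshape(2, 3)
wrap_pad(x, [(1, 1), (2, 1)]).shape  # → (4, 6)
```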

@ricardoV94 (Member)

Regarding JAX: do you need to implement a specific dispatch? For instance, for einsum I don't think we'll need one, because the OFG expression will be as good as what they do internally (since we copied it from them).

@jessegrabowski (Member Author)

No idea on the JAX dispatch. I just assumed I should.

@jessegrabowski (Member Author)

Don't understand why the doctest for pad is failing

@ricardoV94 (Member)

> Don't understand why the doctest for pad is failing

It says there is output that was not expected. If you have a print somewhere, the doctest has to expect its output right after the statement.

@ricardoV94 (Member)

You can also run doctest locally btw

@ricardoV94 (Member) commented Jul 13, 2024

Something like pytest --doctest-modules pytensor/tensor/pad.py --verbose
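A stray print is the usual culprit: doctest compares everything written to stdout, so any output the docstring doesn't list makes the example fail. A minimal illustration (hypothetical function, not from the PR), run through the doctest machinery directly:

```python
import doctest

def double(x):
    """
    >>> double(3)
    6
    """
    print("debug:", x)  # extra stdout the docstring doesn't expect
    return x * 2

finder = doctest.DocTestFinder()
runner = doctest.DocTestRunner(verbose=False)
for test in finder.find(double, module=False, globs={"double": double}):
    runner.run(test)

# doctest got "debug: 3" followed by "6", but only "6" was expected,
# so the single example is counted as a failure
runner.failures  # → 1
```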

@ricardoV94 (Member)

We should open a follow-up issue for performance. With the reshape and concatenation approach, we're doing a lot of copies. We should see how much better it would be to use scans with set_subtensor, like you tried halfway through.
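A sketch of that alternative in NumPy terms (hypothetical helper, not the PR's code; in PyTensor the write would be a set_subtensor on an allocated output rather than an in-place assignment):

```python
import numpy as np

def constant_pad_inplace(x, pad_widths, value=0.0):
    # Allocate the output once and write the original block into it,
    # instead of building the result from reshapes and concatenations.
    out_shape = tuple(s + b + a for s, (b, a) in zip(x.shape, pad_widths))
    out = np.full(out_shape, value, dtype=x.dtype)
    core = tuple(slice(b, b + s) for s, (b, a) in zip(x.shape, pad_widths))
    out[core] = x  # the single set_subtensor-style write
    return out

x = np.ones((2, 3))
constant_pad_inplace(x, [(1, 1), (0, 2)]).shape  # → (4, 5)
```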

@jessegrabowski (Member Author)

I kind of just want to skip the segfault test and come back to it later. I am trying to debug, but I'm not really sure what's going on. It runs fine when the NUMBA_DISABLE_JIT flag is set. My suspicion is an out-of-range index, but I can't reproduce the error in a notebook. I'd be fine removing the numba tests altogether and considering this mode unsupported, but every other mode passes, so I don't know. I just want to get this over the finish line.

@ricardoV94 (Member) left a comment


This is huge!

@jessegrabowski jessegrabowski merged commit 981688c into pymc-devs:main Jul 19, 2024
58 of 59 checks passed
Ch0ronomato pushed a commit to Ch0ronomato/pytensor that referenced this pull request Aug 15, 2024
* Add `pt.pad`

* Refactor linspace, logspace, and geomspace to match numpy implementation

* Add `pt.flip`

* Move `flip` to `tensor/subtensor.py`, add docstring

* Move `slice_at_axis` to `tensor/subtensor` and expose it in `pytensor.tensor`
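The `slice_at_axis` helper mentioned in the commit list (mirroring the private utility in numpy.lib.arraypad) can be sketched as:

```python
import numpy as np

def slice_at_axis(sl, axis):
    # Build an index tuple that applies `sl` along `axis` and leaves
    # every other dimension untouched.
    return (slice(None),) * axis + (sl,) + (Ellipsis,)

x = np.arange(24).reshape(2, 3, 4)
x[slice_at_axis(slice(0, 2), axis=2)].shape  # → (2, 3, 2)
```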
Development

Successfully merging this pull request may close these issues.

ENH: Implement pt.pad
3 participants