Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update grids.py and interpolation.py #69

Merged
merged 12 commits into from
May 24, 2024

Conversation

timmens
Copy link
Member

@timmens timmens commented Apr 23, 2024

In this PR, I

  1. Add types to the grids.py module
    For this I define the Scalar type, which includes int, float, and jax.Array, because many JAX and LCM functions workon zero-dimensiona JAX arrays (and that often faster compared to float's).
  2. Delete the interpolation.py module, because it was not used anymore
  3. Create more precise types for the grid specification of discrete and continuous variables.
  4. Delete a JAX config update that constraints computation on the CPU, which was added for testing.
  5. Added more unit and illustrative tests

Copy link
Member

@hmgaudecker hmgaudecker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had a brief look because I was thinking about extending the scope to log-spaced grids... And they are already there, had not realised this. Great!

Looking at that, however: Could this be extended to let the user supply her own grid directly? Thinking about cases where you want to include certain points directly (e.g., kinks in the tax schedule, ...). Probably more a matter of documenation / UI design than anything else, not necessarily for this PR.

src/lcm/grids.py Outdated Show resolved Hide resolved
@timmens timmens changed the title Update grids.py, interfaces.py and interpolation.py Update grids.py and interpolation.py May 22, 2024
@timmens
Copy link
Member Author

timmens commented May 22, 2024

Could this be extended to let the user supply her own grid directly?

In principle, yes, but it is not completely trivial. I have opened an issue but set the priority to low for now (#76)

@timmens timmens requested a review from hmgaudecker May 22, 2024 18:21
@timmens
Copy link
Member Author

timmens commented May 22, 2024

@hmgaudecker

I am fairly happy with the new changes and would be okay with merging this PR without your review if you do not find the time right now. Of course, if you find the time, I'd be happy to get your feedback.

There is no time pressure on this PR, so a review in the next 1-2 weeks would suffice.

Copy link
Member

@hmgaudecker hmgaudecker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code is excellent as usual, thanks!

Just one clarifying question where I cannot dig deeper myself right now.

Also, since I continue to be confused about what is happening for values outside a grid, maybe it is useful to add a test defining that behaviour? (apologies if we have one and I missed it)

stop: Scalar,
n_points: int,
) -> Scalar:
"""Map a value into the input needed for jax.scipy.ndimage.map_coordinates."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could not quite find out on the phone what is happening when value is outside the boundaries of the grid now. Is it possible to extrapolate? (found the options for the Jax/scipy function online, but no explanation what they mean, nor where we call the function)

Also it is not fully clear to me why we use the interpolation from ndimage instead of spicy.interpolate? E.g., this one would directly address #76 and probably the interface would be closer to the interpolator for rectangular grids? Are those not available in Jax?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. If the value is outside the grid, the functions in grids.py will return a coordinate smaller than 0 or larger than the last index (at least in some dimension). In my opinion, the behavior of the interpolator, in this case, should not be handled in grids.py. The interpolator is defined and used in the function_evaluator.py module (look for the _get_interpolator function). Here we set the option mode="nearest". This will choose the closest point in the grid (i.e., in the 1D case, if the coordinate is -1, it will return the value at the coordinate 0). I will clean up the function_evaluator.py module soon, in which case this will be better documented. For now, I've added an illustrative test to test_grids.py.

  2. That is a good question. I will reply to you about this and add the result to the developer notes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re 1. — I think we'll need to discuss that once more. Taking the example of consumption, I would think that "nearest" means that its marginal utility is zero outside the grid, which will lead to very strange behaviour if the grid has been chosen poorly. Will be very hard to "debug" then (no bug in the strict sense, ofc).

My guess would be that "reflect" may do what I would think is the best approximation (take the two extreme gridpoints and extrapolate).

Maybe we could expose that kwarg instead of hard-coding it in _get_interpolator? In any case, we should use whatever is available in the interpolation function we'll end up using.

Not necessarily for this PR, feel free to open an issue instead.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re 1. - These are good points. I've opened an issue task regarding this. Will be worked on once I tackle the interpolation functionality in LCM.

Re 2. - I've also added a task to that issue to compare the feature set of RegularGridInterpolator to the strategy we are using now. First tests show that it may be slower, and there is some confusion on whether it works with log-grids. The other scipy interpolation functions are not implemented in JAX.

@timmens timmens merged commit 20db0ec into main May 24, 2024
5 checks passed
@timmens timmens deleted the update-grids-interfaces-interpolation branch May 24, 2024 08:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants