-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update grids.py and interpolation.py (#69)
- Loading branch information
Showing
9 changed files
with
316 additions
and
114 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,3 @@ | ||
import jax | ||
|
||
from lcm import mark | ||
|
||
jax.config.update("jax_platform_name", "cpu") | ||
|
||
|
||
__all__ = ["mark"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,59 +1,111 @@ | ||
"""Functions to generate and work with different kinds of grids. | ||
Grid generation functions have the arguments: | ||
Grid generation functions must have the following signature: | ||
- start | ||
- stop | ||
- n_points | ||
Signature (start: Scalar, stop: Scalar, n_points: int) -> jax.Array | ||
interpolation info functions have the arguments | ||
They take start and end points and create a grid of points between them. | ||
- value | ||
- start | ||
- stop | ||
- n_points | ||
Interpolation info functions must have the following signature: | ||
Signature ( | ||
value: Scalar, | ||
start: Scalar, | ||
stop: Scalar, | ||
n_points: int | ||
) -> Scalar | ||
They take the information required to generate a grid, and return an index corresponding | ||
to the value, which is a point in the space but not necessarily a grid point. | ||
Some of the arguments will not be used by all functions but the aligned interface makes | ||
it easy to call functions interchangeably. | ||
""" | ||
|
||
import jax.numpy as jnp | ||
from jax import Array | ||
|
||
from lcm.typing import Scalar | ||
|
||
|
||
def linspace(start: Scalar, stop: Scalar, n_points: int) -> Array: | ||
"""Wrapper around jnp.linspace. | ||
Returns a linearly spaced grid between start and stop with n_points, including both | ||
endpoints. | ||
def linspace(start, stop, n_points): | ||
""" | ||
return jnp.linspace(start, stop, n_points) | ||
|
||
|
||
def get_linspace_coordinate(value, start, stop, n_points): | ||
"""Map a value into the input needed for map_coordinates.""" | ||
def get_linspace_coordinate( | ||
value: Scalar, | ||
start: Scalar, | ||
stop: Scalar, | ||
n_points: int, | ||
) -> Scalar: | ||
"""Map a value into the input needed for jax.scipy.ndimage.map_coordinates.""" | ||
step_length = (stop - start) / (n_points - 1) | ||
return (value - start) / step_length | ||
|
||
|
||
def logspace(start, stop, n_points): | ||
start_lin = jnp.log(start) | ||
stop_lin = jnp.log(stop) | ||
return jnp.logspace(start_lin, stop_lin, n_points, base=2.718281828459045) | ||
def logspace(start: Scalar, stop: Scalar, n_points: int) -> Array: | ||
"""Wrapper around jnp.logspace. | ||
Returns a logarithmically spaced grid between start and stop with n_points, | ||
including both endpoints. | ||
def get_logspace_coordinate(value, start, stop, n_points): | ||
"""Map a value into the input needed for map_coordinates.""" | ||
start_lin = jnp.log(start) | ||
stop_lin = jnp.log(stop) | ||
value_lin = jnp.log(value) | ||
From the JAX documentation: | ||
mapped_point_lin = get_linspace_coordinate(value_lin, start_lin, stop_lin, n_points) | ||
In linear space, the sequence starts at base ** start (base to the power of | ||
start) and ends with base ** stop [...]. | ||
# Calculate lower and upper point on log/exp scale | ||
step_length = (stop_lin - start_lin) / (n_points - 1) | ||
rank_lower_gridpoint = jnp.floor(mapped_point_lin) | ||
rank_upper_gridpoint = rank_lower_gridpoint + 1 | ||
""" | ||
start_linear = jnp.log(start) | ||
stop_linear = jnp.log(stop) | ||
return jnp.logspace(start_linear, stop_linear, n_points, base=jnp.e) | ||
|
||
|
||
def get_logspace_coordinate( | ||
value: Scalar, | ||
start: Scalar, | ||
stop: Scalar, | ||
n_points: int, | ||
) -> Scalar: | ||
"""Map a value into the input needed for jax.scipy.ndimage.map_coordinates.""" | ||
# Transform start, stop, and value to linear scale | ||
start_linear = jnp.log(start) | ||
stop_linear = jnp.log(stop) | ||
value_linear = jnp.log(value) | ||
|
||
# Calc | ||
lower_gridpoint = jnp.exp(start_lin + step_length * rank_lower_gridpoint) | ||
upper_gridpoint = jnp.exp(start_lin + step_length * rank_upper_gridpoint) | ||
# Calculate coordinate in linear space | ||
coordinate_in_linear_space = get_linspace_coordinate( | ||
value_linear, | ||
start_linear, | ||
stop_linear, | ||
n_points, | ||
) | ||
|
||
# Calculate rank of lower and upper point in logarithmic space | ||
rank_lower_gridpoint = jnp.floor(coordinate_in_linear_space) | ||
rank_upper_gridpoint = rank_lower_gridpoint + 1 | ||
|
||
# Calculate transformed mapped point | ||
decimal = (value - lower_gridpoint) / (upper_gridpoint - lower_gridpoint) | ||
return rank_lower_gridpoint + decimal | ||
# Calculate lower and upper point in logarithmic space | ||
step_length_linear = (stop_linear - start_linear) / (n_points - 1) | ||
lower_gridpoint = jnp.exp(start_linear + step_length_linear * rank_lower_gridpoint) | ||
upper_gridpoint = jnp.exp(start_linear + step_length_linear * rank_upper_gridpoint) | ||
|
||
# Calculate the decimal part of coordinate | ||
logarithmic_step_size_at_coordinate = upper_gridpoint - lower_gridpoint | ||
distance_from_lower_gridpoint = value - lower_gridpoint | ||
|
||
# If the distance from the lower gridpoint is zero, the coordinate corresponds to | ||
# the rank of the lower gridpoint. The other extreme is when the distance is equal | ||
# to the logarithmic step size at the coordinate, in which case the coordinate | ||
# corresponds to the rank of the upper gridpoint. For values in between, the | ||
# coordinate lies on a linear scale between the ranks of the lower and upper | ||
# gridpoints. | ||
decimal_part = distance_from_lower_gridpoint / logarithmic_step_size_at_coordinate | ||
return rank_lower_gridpoint + decimal_part |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.