Skip to content

Commit

Permalink
Update grids.py and interpolation.py (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens authored May 24, 2024
1 parent b6730db commit 20db0ec
Show file tree
Hide file tree
Showing 9 changed files with 316 additions and 114 deletions.
5 changes: 0 additions & 5 deletions src/lcm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
import jax

from lcm import mark

jax.config.update("jax_platform_name", "cpu")


__all__ = ["mark"]
29 changes: 15 additions & 14 deletions src/lcm/function_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import lcm.grids as grids_module
from lcm.functools import all_as_kwargs
from lcm.interfaces import ContinuousGridInfo, ContinuousGridType


def get_function_evaluator(
Expand Down Expand Up @@ -117,11 +118,11 @@ def get_function_evaluator(
# ==============================================================================
# create functions to find coordinates for the interpolation
# ==============================================================================
for var, grid_info in space_info.interpolation_info.items():
for var, grid_spec in space_info.interpolation_info.items():
funcs[f"__{var}_coord__"] = _get_coordinate_finder(
in_name=input_prefix + var,
grid_type=grid_info.kind,
grid_info=grid_info.specs,
grid_type=grid_spec.kind,
grid_info=grid_spec.info,
)

# ==============================================================================
Expand Down Expand Up @@ -219,30 +220,30 @@ def lookup_wrapper(*args, **kwargs):
return lookup_wrapper


def _get_coordinate_finder(in_name, grid_type, grid_info):
def _get_coordinate_finder(
in_name: str,
grid_type: ContinuousGridType,
grid_info: ContinuousGridInfo,
):
"""Create a function that translates a value into coordinates on a grid.
The resulting coordinates can be used to do linear interpolation via
jax.scipy.ndimage.map_coordinates.
Args:
in_name (str): Name via which the value to be translated into coordinates
will be passed into the resulting function.
grid_type (str): Type of the grid, e.g. "linspace" or "logspace". The type of
grid must be implemented in lcm.grids.
grid_info (dict): Dict with information that defines the grid. E.g. for a
linspace those are {"start": float, "stop": float, "n_points": int}. See
lcm.grids for details.
in_name: Name via which the value to be translated into coordinates will be
passed into the resulting function.
grid_type: Type of the grid, e.g. "linspace" or "logspace". The type of grid
must be implemented in lcm.grids.
grid_info: Information on how to build the grid, e.g. start, stop, and n_points.
Returns:
callable: A callable with keyword-only argument [in_name] that translates a
value into coordinates on a grid.
"""
grid_info = {} if grid_info is None else grid_info

raw_func = getattr(grids_module, f"get_{grid_type}_coordinate")
partialled_func = partial(raw_func, **grid_info)
partialled_func = partial(raw_func, **grid_info._asdict())

@with_signature(args=[in_name])
def find_coordinate(*args, **kwargs):
Expand Down
116 changes: 84 additions & 32 deletions src/lcm/grids.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,111 @@
"""Functions to generate and work with different kinds of grids.
Grid generation functions have the arguments:
Grid generation functions must have the following signature:
- start
- stop
- n_points
Signature (start: Scalar, stop: Scalar, n_points: int) -> jax.Array
interpolation info functions have the arguments
They take start and end points and create a grid of points between them.
- value
- start
- stop
- n_points
Interpolation info functions must have the following signature:
Signature (
value: Scalar,
start: Scalar,
stop: Scalar,
n_points: int
) -> Scalar
They take the information required to generate a grid, and return an index corresponding
to the value, which is a point in the space but not necessarily a grid point.
Some of the arguments will not be used by all functions but the aligned interface makes
it easy to call functions interchangeably.
"""

import jax.numpy as jnp
from jax import Array

from lcm.typing import Scalar


def linspace(start: Scalar, stop: Scalar, n_points: int) -> Array:
"""Wrapper around jnp.linspace.
Returns a linearly spaced grid between start and stop with n_points, including both
endpoints.
def linspace(start, stop, n_points):
"""
return jnp.linspace(start, stop, n_points)


def get_linspace_coordinate(value, start, stop, n_points):
"""Map a value into the input needed for map_coordinates."""
def get_linspace_coordinate(
value: Scalar,
start: Scalar,
stop: Scalar,
n_points: int,
) -> Scalar:
"""Map a value into the input needed for jax.scipy.ndimage.map_coordinates."""
step_length = (stop - start) / (n_points - 1)
return (value - start) / step_length


def logspace(start, stop, n_points):
start_lin = jnp.log(start)
stop_lin = jnp.log(stop)
return jnp.logspace(start_lin, stop_lin, n_points, base=2.718281828459045)
def logspace(start: Scalar, stop: Scalar, n_points: int) -> Array:
"""Wrapper around jnp.logspace.
Returns a logarithmically spaced grid between start and stop with n_points,
including both endpoints.
def get_logspace_coordinate(value, start, stop, n_points):
"""Map a value into the input needed for map_coordinates."""
start_lin = jnp.log(start)
stop_lin = jnp.log(stop)
value_lin = jnp.log(value)
From the JAX documentation:
mapped_point_lin = get_linspace_coordinate(value_lin, start_lin, stop_lin, n_points)
In linear space, the sequence starts at base ** start (base to the power of
start) and ends with base ** stop [...].
# Calculate lower and upper point on log/exp scale
step_length = (stop_lin - start_lin) / (n_points - 1)
rank_lower_gridpoint = jnp.floor(mapped_point_lin)
rank_upper_gridpoint = rank_lower_gridpoint + 1
"""
start_linear = jnp.log(start)
stop_linear = jnp.log(stop)
return jnp.logspace(start_linear, stop_linear, n_points, base=jnp.e)


def get_logspace_coordinate(
value: Scalar,
start: Scalar,
stop: Scalar,
n_points: int,
) -> Scalar:
"""Map a value into the input needed for jax.scipy.ndimage.map_coordinates."""
# Transform start, stop, and value to linear scale
start_linear = jnp.log(start)
stop_linear = jnp.log(stop)
value_linear = jnp.log(value)

# Calc
lower_gridpoint = jnp.exp(start_lin + step_length * rank_lower_gridpoint)
upper_gridpoint = jnp.exp(start_lin + step_length * rank_upper_gridpoint)
# Calculate coordinate in linear space
coordinate_in_linear_space = get_linspace_coordinate(
value_linear,
start_linear,
stop_linear,
n_points,
)

# Calculate rank of lower and upper point in logarithmic space
rank_lower_gridpoint = jnp.floor(coordinate_in_linear_space)
rank_upper_gridpoint = rank_lower_gridpoint + 1

# Calculate transformed mapped point
decimal = (value - lower_gridpoint) / (upper_gridpoint - lower_gridpoint)
return rank_lower_gridpoint + decimal
# Calculate lower and upper point in logarithmic space
step_length_linear = (stop_linear - start_linear) / (n_points - 1)
lower_gridpoint = jnp.exp(start_linear + step_length_linear * rank_lower_gridpoint)
upper_gridpoint = jnp.exp(start_linear + step_length_linear * rank_upper_gridpoint)

# Calculate the decimal part of coordinate
logarithmic_step_size_at_coordinate = upper_gridpoint - lower_gridpoint
distance_from_lower_gridpoint = value - lower_gridpoint

# If the distance from the lower gridpoint is zero, the coordinate corresponds to
# the rank of the lower gridpoint. The other extreme is when the distance is equal
# to the logarithmic step size at the coordinate, in which case the coordinate
# corresponds to the rank of the upper gridpoint. For values in between, the
# coordinate lies on a linear scale between the ranks of the lower and upper
# gridpoints.
decimal_part = distance_from_lower_gridpoint / logarithmic_step_size_at_coordinate
return rank_lower_gridpoint + decimal_part
48 changes: 38 additions & 10 deletions src/lcm/interfaces.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import NamedTuple
from typing import Literal, NamedTuple

import numpy as np
import pandas as pd
from jax import Array

from lcm.typing import Scalar


class IndexerInfo(NamedTuple):
Expand All @@ -25,18 +27,44 @@ class IndexerInfo(NamedTuple):
out_name: str


class GridSpec(NamedTuple):
"""Information needed to define or interpret a grid.
class ContinuousGridInfo(NamedTuple):
"""Information on how to build a grid for a continuous variable.
Attributes:
start (Scalar): Start of the grid.
stop (Scalar): End of the grid.
n_points (int): Number of points in the grid.
"""

start: Scalar
stop: Scalar
n_points: int


ContinuousGridType = Literal["linspace", "logspace"]


class ContinuousGridSpec(NamedTuple):
"""Specification of a grid for continuous variables.
Contains all information necessary to build and work with a grid of a continuous
variable.
Attributes:
kind (str): Name of a grid type implemented in lcm.grids.
specs (dict, np.ndarray): Specification of the grid. E.g. {"start": float,
"stop": float, "n_points": int} for a linspace.
kind (ContinuousGridType): Name of a grid type implemented in lcm.grids.
info (ContinuousGridInfo): Information on how to build the grid. E.g., start,
stop, and n_points.
"""

kind: str
specs: dict | np.ndarray
kind: ContinuousGridType
info: ContinuousGridInfo


DiscreteGridSpec = Array
GridSpec = ContinuousGridSpec | DiscreteGridSpec


class Space(NamedTuple):
Expand Down Expand Up @@ -71,7 +99,7 @@ class SpaceInfo(NamedTuple):

axis_names: list[str]
lookup_info: dict[str, list[str]]
interpolation_info: dict[str, GridSpec]
interpolation_info: dict[str, ContinuousGridSpec]
indexer_infos: list[IndexerInfo]


Expand Down
41 changes: 23 additions & 18 deletions src/lcm/process_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
import pandas as pd
from dags import get_ancestors
from dags.signature import with_signature
from jax import Array

import lcm.grids as grids_module
from lcm.create_params_template import create_params_template
from lcm.function_evaluator import get_label_translator
from lcm.functools import all_as_args, all_as_kwargs
from lcm.interfaces import GridSpec, Model
from lcm.interfaces import ContinuousGridInfo, ContinuousGridSpec, GridSpec, Model


def process_model(user_model):
Expand Down Expand Up @@ -182,41 +183,45 @@ def _get_gridspecs(user_model, variable_info):
if "options" in spec:
variables[name] = spec["options"]
else:
variables[name] = GridSpec(
grid_info = {k: v for k, v in spec.items() if k != "grid_type"}
variables[name] = ContinuousGridSpec(
kind=spec["grid_type"],
specs={k: v for k, v in spec.items() if k != "grid_type"},
info=ContinuousGridInfo(**grid_info),
)

order = variable_info.index.tolist()
return {k: variables[k] for k in order}


def _get_grids(gridspecs, variable_info):
def _get_grids(
gridspecs: dict[str, GridSpec],
variable_info: pd.DataFrame,
) -> dict[str, Array]:
"""Create a dictionary of grids for each variable in the model.
Args:
gridspecs (dict): Dictionary containing all variables of the model. The keys
are the names of the variables. The values describe which values the
variable can take. For discrete variables these are the options. For
continuous variables this is information about how to build the grids.
variable_info (pandas.DataFrame): A table with information about all
variables in the model. The index contains the name of a model variable.
The columns are booleans that are True if the variable has the
corresponding property. The columns are: is_state, is_choice, is_continuous,
is_discrete, is_sparse, is_dense.
gridspecs: Dictionary containing all variables of the model. The keys are the
names of the variables. The values describe which values the variable can
take. For discrete variables these are the options (jnp.array). For
continuous variables this is information about how to build the grids
(ContinuousGridSpec).
variable_info: A table with information about all variables in the model. The
index contains the name of a model variable. The columns are booleans that
are True if the variable has the corresponding property. The columns are:
is_state, is_choice, is_continuous, is_discrete, is_sparse, is_dense.
Returns:
dict: Dictionary containing all variables of the model. The keys are
the names of the variables. The values are the grids.
"""
grids = {}
for name, grid_info in gridspecs.items():
if variable_info.loc[name, "is_discrete"]:
grids[name] = jnp.array(grid_info)
for name, grid_spec in gridspecs.items():
if isinstance(grid_spec, ContinuousGridSpec):
build_grid = getattr(grids_module, grid_spec.kind)
grids[name] = build_grid(**grid_spec.info._asdict())
else:
func = getattr(grids_module, grid_info.kind)
grids[name] = func(**grid_info.specs)
grids[name] = jnp.array(grid_spec)

order = variable_info.index.tolist()
return {k: grids[k] for k in order}
Expand Down
4 changes: 4 additions & 0 deletions src/lcm/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

from jax import Array

# Many JAX functions are designed to work with scalar numerical values. This also
# includes zero dimensional jax arrays.
Scalar = int | float | Array


class SegmentInfo(TypedDict):
"""Information on segments which is passed to `jax.ops.segment_max`.
Expand Down
Loading

0 comments on commit 20db0ec

Please sign in to comment.