Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Martin/dist components #5

Merged
merged 14 commits into from
Jan 27, 2025
8 changes: 4 additions & 4 deletions jax_ib/base/IBM_Force.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,8 @@ def immersed_boundary_force(velocity_field: tuple[GridVariable, GridVariable],
velocity_field, particle, dirac_delta_approx, surface_fn,t, dt)
forcex += per_object_forcex
forcey += per_object_forcey
return (GridVariable(GridArray(forcex,velocity_field[0].offset,velocity_field[0].grid), velocity_field[0].bc),
GridVariable(GridArray(forcey,velocity_field[1].offset,velocity_field[1].grid), velocity_field[1].bc))
return (GridVariable(GridArray(forcex,velocity_field[0].offset,velocity_field[0].grid, velocity_field[0].width), velocity_field[0].bc),
GridVariable(GridArray(forcey,velocity_field[1].offset,velocity_field[1].grid, velocity_field[1].width), velocity_field[1].bc))


def immersed_boundary_force_per_particle_deprecated(
Expand Down Expand Up @@ -325,6 +325,6 @@ def immersed_boundary_force_deprecated(velocity_field: tuple[GridVariable, GridV
surface_fn,dx_dt,domega_dt,rotation,dt)
forcex += per_object_forcex
forcey += per_object_forcey
return (GridVariable(GridArray(forcex,velocity_field[0].offset,velocity_field[0].grid), velocity_field[0].bc),
GridVariable(GridArray(forcey,velocity_field[1].offset,velocity_field[1].grid), velocity_field[1].bc))
return (GridVariable(GridArray(forcex,velocity_field[0].offset,velocity_field[0].grid, velocity_field[0].width), velocity_field[0].bc),
GridVariable(GridArray(forcey,velocity_field[1].offset,velocity_field[1].grid, velocity_field[1].width), velocity_field[1].bc))

2 changes: 1 addition & 1 deletion jax_ib/base/advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def advect_van_leer(
# for negative velocity we simply need to shift the correction along v axis.
# Cast to GridVariable so that we can apply a shift() operation.
forward_correction_array = grids.GridVariable(
grids.GridArray(forward_correction, u.offset, u.grid), u.bc)
grids.GridArray(forward_correction, u.offset, u.grid, u.width), u.bc)
backward_correction_array = forward_correction_array.shift(+1, axis)
backward_correction = backward_correction_array.data
abs_velocity = abs(u.array)
Expand Down
8 changes: 8 additions & 0 deletions jax_ib/base/array_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,14 @@ def laplacian_matrix(size: int, step: float) -> np.ndarray:
column[1] = column[-1] = 1 / step**2
return scipy.linalg.circulant(column)

def laplacian_column(size: int, step: float) -> np.ndarray:
"""Create 1D Laplacian operator matrix, with periodic BC."""
step = float(step) # mganahl: Grid now uses a jax.Array type for `step`
column = np.zeros(size)
column[0] = -2 / step**2
column[1] = column[-1] = 1 / step**2
return column

def laplacian_matrix_neumann(size: int, step: float) -> np.ndarray:
"""Create 1D Laplacian operator matrix, with homogeneous Neumann BC."""
column = np.zeros(size)
Expand Down
4 changes: 2 additions & 2 deletions jax_ib/base/boundaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def _needs_pad_with_boundary_value():
else:
raise ValueError('invalid boundary type')

return GridArray(data, tuple(offset), u.grid)
return GridArray(data, tuple(offset), u.grid, u.width)

def _trim(
self,
Expand Down Expand Up @@ -303,7 +303,7 @@ def _trim(
data = lax.slice_in_dim(u.data, padding[0], limit_index, axis=axis)
offset = list(u.offset)
offset[axis] += padding[0]
return GridArray(data, tuple(offset), u.grid)
return GridArray(data, tuple(offset), u.grid, u.width)

def _trim_padding(self, u: grids.GridArray, axis=0):
"""Trim all padding from a GridArray.
Expand Down
24 changes: 20 additions & 4 deletions jax_ib/base/convolution_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,26 @@ def gaussian(x: jax.Array,mu: jax.Array,sigma:jax.Array)->float:


def mesh_convolve(field:GridVariable,
xp:jax.Array,
yp:jax.Array,
x:jax.Array,
y:jax.Array,
dirac_delta_approx: callable, axis_names:list[str]) -> jax.Array:
local_conv = convolve(field, xp, yp, dirac_delta_approx)
"""
Compute the convolution of sharded array `field` with 2d-dirac-delta functions located at `x, y`.
The convolution is computed for each pair `x[i], y[i]` in parallel.

Args:
field: GridVariable of the field
x, y: locations of the dirac-delta peaks
dirac_delta_approx: Function approximating a dirac-delta function in 1d.
Expected function signature is `dirac_delta_approx(x, X, dx)`, with
`x` a float, `X` a `jax.Array` of shape `field.data.shape`, and `dx`
a float.
axis_names: The names of the mapped axes of the device mesh.

Returns:
jax.Array: the convolution result.
"""
local_conv = convolve(field, x, y, dirac_delta_approx)
return jax.lax.psum(
jax.lax.psum(local_conv, axis_name = axis_names[0]),
axis_name = axis_names[1])
Expand Down Expand Up @@ -88,4 +104,4 @@ def foo_pmap(tree_arg):
mapped.append([xp[i*n:(i+1)*n],yp[i*n:(i+1)*n]])
arr = jnp.array(mapped)
U_deltas = jax.pmap(foo_pmap)(jnp.array(mapped))
return U_deltas.flatten()
return U_deltas.flatten()
41 changes: 41 additions & 0 deletions jax_ib/base/fast_diagonalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import jax
from jax import lax
import jax.numpy as jnp
from jax_ib.base import fft
import numpy as np


Expand Down Expand Up @@ -257,3 +258,43 @@ def func(v):
return transform(func, operators, dtype, hermitian=hermitian,
circulant=circulant, implementation=implementation,
precision=precision)


def pseudo_poisson_inversion(
eigenvalues: jax.Array,
dtype: jnp.dtype,
axis_names:tuple[str],
cutoff: Optional[float] = None,
) -> Callable[[Array], Array]:
"""Invert a linear operator written as a sum of operators on each axis.

Args:
operators: forward linear operators as matrices, applied along each axis.
Each of these matrices is diagonalized.
dtype: dtype of the right-hand-side.
hermitian: whether or not all linear operator are Hermitian (i.e., symmetric
in the real valued case).
circulant: whether or not all linear operators are circulant.
implementation: how to implement fast diagonalization.
precision: numerical precision for matrix multplication. Only relevant on
TPUs.
cutoff: eigenvalues with absolute value smaller than this number are
discarded rather than being inverted. By default, uses 10 times floating
point epsilon.

Returns:
A function that computes the pseudo-inverse of the indicated operator.
"""
if cutoff is None:
cutoff = 10 * jnp.finfo(dtype).eps

def func(v):
return jnp.where(abs(v) > cutoff, 1 / v, 0)

"""Fast diagonalization by Fast Fourier Transform."""
diagonal = func(eigenvalues)
def apply(rhs: Array) -> Array:
return fft.ifft_2d(diagonal * fft.fft_2d(rhs, axis_names), axis_names).astype(dtype)
return apply


1 change: 1 addition & 0 deletions jax_ib/base/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def _ifft1d_j(array: jax.Array, axis_name: str)-> jax.Array:

return _get_fft_j(axis_name, BWD)(array)


def fft(array: jax.Array, axis:int, axis_name:str):
"""
Compute the 1d-FFT of a 2d-array `array` along axis `axis` with name `axis_name`
Expand Down
8 changes: 4 additions & 4 deletions jax_ib/base/finite_differences.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def stencil_sum(*arrays: GridArray) -> GridArray:
# Actually passed: (iterable: Generator[Union[jax.interpreters.xla.DeviceArray, numpy.ndarray], Any, None])
result = sum(array.data for array in arrays) # type: ignore
grid = grids.consistent_grid(*arrays)
return grids.GridArray(result, offset, grid)
return grids.GridArray(result, offset, grid, arrays[0].width)


# incompatible with typing.overload
Expand All @@ -73,7 +73,7 @@ def central_difference(
...


def central_difference(u, axis=None):
def central_difference(u, axis=None)->GridArray:
"""Approximates grads with central differences."""
if axis is None:
axis = range(u.grid.ndim)
Expand All @@ -94,7 +94,7 @@ def backward_difference(
...


def backward_difference(u, axis=None):
def backward_difference(u, axis=None)->GridArray:
"""
First order finite-difference approximation of the backward gradient of `u`
"""
Expand All @@ -118,7 +118,7 @@ def forward_difference(
...


def forward_difference(u, axis=None):
def forward_difference(u:GridVariable, axis=None)->GridArray:
"""
First order finite-difference approximation of the forward gradient of `u`
"""
Expand Down
Loading