Skip to content

Commit

Permalink
Merge pull request #574 from HajimeKawahara/nustitch
Browse files Browse the repository at this point in the history
unittests added for preparing nu stitch (this is not nu stitch itself)
  • Loading branch information
HajimeKawahara authored Feb 6, 2025
2 parents bf45c9f + 676ca78 commit 646143e
Show file tree
Hide file tree
Showing 16 changed files with 613 additions and 55 deletions.
4 changes: 2 additions & 2 deletions documents/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ It enables fully Bayesian inference for high-dispersion data, fitting line-by-li
-- from molecular/atomic databases to real spectra --
by integrating with Hamiltonian Monte Carlo - No U Turn Sampler (HMC-NUTS), Stochastic Variational Inference (SVI),
Nested Sampling, and other inference techniques available in modern probabilistic programming frameworks
such as NumPyro <https://github.com/pyro-ppl/numpyro>.
such as `NumPyro <https://github.com/pyro-ppl/numpyro>`_.
So, the notable features of ExoJAX are summarized as

- **HMC-NUTS, SVI, Nested Sampling, Gradient-based Optimizer available**
- **HMC-NUTS, SVI, Nested Sampling, Gradient-based Inference Techiques and Optimizers Available**
- **Easy to use the latest molecular/atomic data in** :doc:`userguide/api`, **and** :doc:`userguide/atomll`
- **A transparent open-source project; anyone who wants to participate can join the development!**

Expand Down
2 changes: 1 addition & 1 deletion src/exojax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = []

__version__ = '1.6'
__version__ = '2.0'
__uri__ = 'http://secondearths.sakura.ne.jp/exojax/'
__author__ = 'Hajime Kawahara and collaborators'
__email__ = 'divrot@gmail.com'
Expand Down
2 changes: 1 addition & 1 deletion src/exojax/signal/ola.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def olaconv(input_matrix_zeropad, fir_filter_zeropad, ndiv, div_length, filter_l


def overlap_and_add(ftarr, output_length, div_length):
"""Compute overlap and add
"""Compute overlap and add using scan
Args:
ftarr (jax.ndarray): filtered input matrix
Expand Down
4 changes: 2 additions & 2 deletions src/exojax/spec/lpf.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def voigt(nuvector, sigmaD, gammaL):


@jit
def vvoigt(numatrix, sigmaD, gammas):
def vvoigt(numatrix, sigmaD, gammaL):
"""Custom JVP version of vmaped voigt profile.
Args:
Expand All @@ -315,7 +315,7 @@ def vvoigt(numatrix, sigmaD, gammas):
Voigt profile vector in R^Nwav
"""
vmap_voigt = vmap(voigt, (0, 0, 0), 0)
return vmap_voigt(numatrix, sigmaD, gammas)
return vmap_voigt(numatrix, sigmaD, gammaL)


@jit
Expand Down
23 changes: 23 additions & 0 deletions src/exojax/spec/lpffilter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import jax.numpy as jnp
from exojax.spec.lpf import voigt

def generate_lpffilter(nfilter, nsigmaD, ngammaL):
"""Generates LPF filter
Args:
nfilter (int): length of the wavenumber grid of lpffilter
nsigmaD (float): normalized gaussian standard deviation, resolution*betaT/nu betaT is the STD of Doppler broadening
ngammaL (float): normalized Lorentzian half width
Notes:
The filter structure is filter[1:M] = vkfilter[M+1:][::-1]m where M=N/2
filter[0] is the DC component, Nyquist component.
filter[M] is the Nyquist component.
Returns:
array: filter
"""
# dq is equivalent to resolution*jnp.log(nu_grid) - resolutiona*jnp.log(nu_grid[0]) (+ Nyquist)
dq = jnp.arange(0, nfilter + 1)
lpffilter_oneside = voigt(dq, nsigmaD, ngammaL)
return jnp.concatenate([lpffilter_oneside, lpffilter_oneside[1:-1][::-1]])
46 changes: 23 additions & 23 deletions src/exojax/spec/modit.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,18 +84,18 @@ def xsvector(cnu, indexnu, R, pmarray, nsigmaD, ngammaL, S, nu_grid, ngammaL_gri
However, this will be changed when cufft fixes the 4GB limit.
Args:
cnu: contribution by npgetix for wavenumber
indexnu: index by npgetix for wavenumber
R: spectral resolution
pmarray: (+1,-1) array whose length of len(nu_grid)+1
nsigmaD: normaized Gaussian STD
gammaL: Lorentzian half width (Nlines)
S: line strength (Nlines)
nu_grid: linear wavenumber grid
gammaL_grid: gammaL grid
cnu: contribution by npgetix for wavenumber
indexnu: index by npgetix for wavenumber
R: spectral resolution
pmarray: (+1,-1) array whose length of len(nu_grid)+1
nsigmaD: normaized Gaussian STD
gammaL: Lorentzian half width (Nlines)
S: line strength (Nlines)
nu_grid: linear wavenumber grid
gammaL_grid: gammaL grid
Returns:
Cross section in the log nu grid
Cross section in the log nu grid
"""

log_ngammaL_grid = jnp.log(ngammaL_grid)
Expand All @@ -106,7 +106,7 @@ def xsvector(cnu, indexnu, R, pmarray, nsigmaD, ngammaL, S, nu_grid, ngammaL_gri


@jit
def xsmatrix(cnu, indexnu, R, pmarray, nsigmaDl, ngammaLM, SijM, nu_grid, dgm_ngammaL):
def xsmatrix(cnu, indexnu, resolution, pmarray, nsigmaDl, ngammaLM, SijM, nu_grid, dgm_ngammaL):
"""Cross section matrix for xsvector (MODIT)
Notes:
Expand All @@ -116,19 +116,19 @@ def xsmatrix(cnu, indexnu, R, pmarray, nsigmaDl, ngammaLM, SijM, nu_grid, dgm_ng
Args:
cnu: contribution by npgetix for wavenumber
indexnu: index by npgetix for wavenumber
R: spectral resolution
pmarray: (+1,-1) array whose length of len(nu_grid)+1
nu_lines: line center (Nlines)
nsigmaDl: normalized doppler sigma in layers in R^(Nlayer x 1)
ngammaLM: gamma factor matrix in R^(Nlayer x Nline)
SijM: line strength matrix in R^(Nlayer x Nline)
nu_grid: linear wavenumber grid
dgm_ngammaL: DIT Grid Matrix for normalized gammaL R^(Nlayer, NDITgrid)
cnu: contribution by npgetix for wavenumber
indexnu: index by npgetix for wavenumber
resolution: spectral resolution, same as resolution (3rd return value) from utils.grids.wavenumber_grid
pmarray: (+1,-1) array whose length of len(nu_grid)+1
nu_lines: line center [Nlines]
nsigmaDl: normalized doppler sigma in layers in [Nlayer, 1]
ngammaLM: gamma factor matrix in [Nlayer, Nline]
SijM: line strength matrix in [Nlayer, Nline]
nu_grid: linear wavenumber grid
dgm_ngammaL: DIT Grid Matrix for normalized gammaL [Nlayer, NDITgrid]
Return:
cross section matrix in R^(Nlayer x Nwav)
cross section matrix in [Nlayer x Nwav]
"""
NDITgrid = jnp.shape(dgm_ngammaL)[1]
Nline = len(cnu)
Expand All @@ -141,7 +141,7 @@ def fxs(x, arr):
Sij = arr[Nline + 1 : 2 * Nline + 1]
ngammaL_grid = arr[2 * Nline + 1 : 2 * Nline + NDITgrid + 1]
arr = xsvector(
cnu, indexnu, R, pmarray, nsigmaD, ngammaL, Sij, nu_grid, ngammaL_grid
cnu, indexnu, resolution, pmarray, nsigmaD, ngammaL, Sij, nu_grid, ngammaL_grid
)
return carry, arr

Expand Down
67 changes: 42 additions & 25 deletions src/exojax/spec/modit_scanfft.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
you should consider to modit_scanfft
"""

import jax.numpy as jnp
from jax import jit, vmap
from jax.lax import scan
from exojax.spec.ditkernel import fold_voigt_kernel_logst
from exojax.spec.lsd import inc2D_givenx

def calc_xsection_from_lsd_scanfft(Slsd, R, pmarray, nsigmaD, nu_grid,
log_ngammaL_grid):

def calc_xsection_from_lsd_scanfft(
Slsd, R, pmarray, nsigmaD, nu_grid, log_ngammaL_grid
):
"""Compute cross section from LSD in MODIT algorithm using scan+fft to avoid 4GB memory limit in fft (see #277)
Args:
Expand All @@ -27,8 +30,10 @@ def calc_xsection_from_lsd_scanfft(Slsd, R, pmarray, nsigmaD, nu_grid,
Cross section in the log nu grid
"""

# add buffer
Sbuf = jnp.vstack([Slsd, jnp.zeros_like(Slsd)])

# layer by layer fft
def f(i, x):
y = jnp.fft.rfft(x)
i = i + 1
Expand All @@ -37,25 +42,33 @@ def f(i, x):
nscan, fftval = scan(f, 0, Sbuf.T)
fftval = fftval.T
Ng_nu = len(nu_grid)
vk = fold_voigt_kernel_logst(jnp.fft.rfftfreq(2 * Ng_nu, 1),
jnp.log(nsigmaD), log_ngammaL_grid, Ng_nu,
pmarray)
fftvalsum = jnp.sum(fftval * vk, axis=(1, ))
return jnp.fft.irfft(fftvalsum)[:Ng_nu] * R / nu_grid

# filter kernel
vk = fold_voigt_kernel_logst(
jnp.fft.rfftfreq(2 * Ng_nu, 1),
jnp.log(nsigmaD),
log_ngammaL_grid,
Ng_nu,
pmarray,
)

# convolves
fftvalsum = jnp.sum(fftval * vk, axis=(1,))
return jnp.fft.irfft(fftvalsum)[:Ng_nu] * R / nu_grid


@jit
def xsvector_scanfft(cnu, indexnu, R, pmarray, nsigmaD, ngammaL, S, nu_grid,
ngammaL_grid):
def xsvector_scanfft(
cnu, indexnu, R, pmarray, nsigmaD, ngammaL, S, nu_grid, ngammaL_grid
):
"""Cross section vector (MODIT scanfft)
Args:
cnu: contribution by npgetix for wavenumber
indexnu: index by npgetix for wavenumber
R: spectral resolution
pmarray: (+1,-1) array whose length of len(nu_grid)+1
nsigmaD: normaized Gaussian STD
nsigmaD: normaized Gaussian STD
gammaL: Lorentzian half width (Nlines)
S: line strength (Nlines)
nu_grid: linear wavenumber grid
Expand All @@ -67,16 +80,17 @@ def xsvector_scanfft(cnu, indexnu, R, pmarray, nsigmaD, ngammaL, S, nu_grid,

log_ngammaL_grid = jnp.log(ngammaL_grid)
lsd_array = jnp.zeros((len(nu_grid), len(ngammaL_grid)))
Slsd = inc2D_givenx(lsd_array, S, cnu, indexnu, jnp.log(ngammaL),
log_ngammaL_grid)
xs = calc_xsection_from_lsd_scanfft(Slsd, R, pmarray, nsigmaD, nu_grid,
log_ngammaL_grid)
Slsd = inc2D_givenx(lsd_array, S, cnu, indexnu, jnp.log(ngammaL), log_ngammaL_grid)
xs = calc_xsection_from_lsd_scanfft(
Slsd, R, pmarray, nsigmaD, nu_grid, log_ngammaL_grid
)
return xs


@jit
def xsmatrix_scanfft(cnu, indexnu, R, pmarray, nsigmaDl, ngammaLM, SijM, nu_grid,
dgm_ngammaL):
def xsmatrix_scanfft(
cnu, indexnu, R, pmarray, nsigmaDl, ngammaLM, SijM, nu_grid, dgm_ngammaL
):
"""Cross section matrix for xsvector (MODIT), scan+fft
Args:
Expand All @@ -101,20 +115,22 @@ def xsmatrix_scanfft(cnu, indexnu, R, pmarray, nsigmaDl, ngammaLM, SijM, nu_grid
def fxs(x, arr):
carry = 0.0
nsigmaD = arr[0:1]
ngammaL = arr[1:Nline + 1]
Sij = arr[Nline + 1:2 * Nline + 1]
ngammaL_grid = arr[2 * Nline + 1:2 * Nline + NDITgrid + 1]
arr = xsvector_scanfft(cnu, indexnu, R, pmarray, nsigmaD, ngammaL, Sij,
nu_grid, ngammaL_grid)
ngammaL = arr[1 : Nline + 1]
Sij = arr[Nline + 1 : 2 * Nline + 1]
ngammaL_grid = arr[2 * Nline + 1 : 2 * Nline + NDITgrid + 1]
arr = xsvector_scanfft(
cnu, indexnu, R, pmarray, nsigmaD, ngammaL, Sij, nu_grid, ngammaL_grid
)
return carry, arr

val, xsm = scan(fxs, 0.0, Mat)
return xsm


@jit
def xsmatrix_vald_scanfft(cnuS, indexnuS, R, pmarray, nsigmaDlS, ngammaLMS, SijMS,
nu_grid, dgm_ngammaLS):
def xsmatrix_vald_scanfft(
cnuS, indexnuS, R, pmarray, nsigmaDlS, ngammaLMS, SijMS, nu_grid, dgm_ngammaLS
):
"""Cross section matrix for xsvector (MODIT) with scan+fft, for VALD lines (asdb)
Args:
Expand All @@ -131,7 +147,8 @@ def xsmatrix_vald_scanfft(cnuS, indexnuS, R, pmarray, nsigmaDlS, ngammaLMS, SijM
Return:
xsmS: cross section matrix [N_species x N_layer x N_wav]
"""
xsmS = jit(vmap(xsmatrix_scanfft, (0, 0, None, None, 0, 0, 0, None, 0)))(\
cnuS, indexnuS, R, pmarray, nsigmaDlS, ngammaLMS, SijMS, nu_grid, dgm_ngammaLS)
xsmS = jit(vmap(xsmatrix_scanfft, (0, 0, None, None, 0, 0, 0, None, 0)))(
cnuS, indexnuS, R, pmarray, nsigmaDlS, ngammaLMS, SijMS, nu_grid, dgm_ngammaLS
)
xsmS = jnp.abs(xsmS)
return xsmS
2 changes: 1 addition & 1 deletion src/exojax/spec/opacalc.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ def __init__(
warnings.warn("Tarr_list/Parr are needed for xsmatrix.", UserWarning)

def __eq__(self, other):
"""eq method for OpaDirect, definied by comparing all the attributes and important status
"""eq method for OpaModit, definied by comparing all the attributes and important status
Args:
other (_type_): _description_
Expand Down
Loading

0 comments on commit 646143e

Please sign in to comment.