Skip to content

Commit

Permalink
Merge pull request #576 from HajimeKawahara/remove_modit_old_xs
Browse files Browse the repository at this point in the history
 Remove modit deprecated calc_xsection_from_lsd, xsvector, xsmatrix
  • Loading branch information
HajimeKawahara authored Feb 6, 2025
2 parents 646143e + 028217c commit ee26954
Show file tree
Hide file tree
Showing 13 changed files with 100 additions and 362 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Cross Section for Many Lines using MODIT
========================================

Update: October 30/2022, Hajime Kawahara
Update: Febrary 7th/2025, Hajime Kawahara

We demonstarte the Modified Discrete Integral Transform (MODIT), which
is the modified version of DIT for exojax. MODIT uses the evenly-spaced
Expand Down Expand Up @@ -122,9 +122,9 @@ Let’s compute the cross section!

.. code:: ipython3
from exojax.spec.modit import xsvector
from exojax.spec.modit_scanfft import xsvector_scanfft
xs = xsvector(cnu, indexnu, R, pmarray, nsigmaD, ngammaL, Sij, nus, ngammaL_grid)
xs = xsvector_scanfft(cnu, indexnu, R, pmarray, nsigmaD, ngammaL, Sij, nus, ngammaL_grid)
Also, we here try the direct computation using LPF for the comparison
purpose
Expand Down Expand Up @@ -177,3 +177,4 @@ There is about 1 % deviation between LPF and MODIT.
.. image:: Cross_Section_using_Modified_Discrete_Integral_Transform_files/Cross_Section_using_Modified_Discrete_Integral_Transform_18_0.png


189 changes: 34 additions & 155 deletions src/exojax/spec/modit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
import jax.numpy as jnp
from jax import jit, vmap
from jax.lax import scan
from exojax.spec.ditkernel import fold_voigt_kernel_logst
from exojax.spec.lsd import inc2D_givenx
from exojax.spec.set_ditgrid import minmax_ditgrid_matrix
from exojax.spec.set_ditgrid import precompute_modit_ditgrid_matrix

Expand All @@ -30,125 +28,6 @@
from exojax.utils.constants import Tref_original


def calc_xsection_from_lsd(Slsd, R, pmarray, nsigmaD, nu_grid, log_ngammaL_grid):
"""Compute cross section from LSD in MODIT algorithm
The original code is rundit_fold_logredst in `addit package <https://github.com/HajimeKawahara/addit>`_ ). MODIT folded voigt for ESLOG for reduced wavenumebr inputs (against the truncation error) for a constant normalized beta
Args:
Slsd: line shape density
R: spectral resolution
pmarray: (+1,-1) array whose length of len(nu_grid)+1
nsigmaD: normaized Gaussian STD
nu_grid: linear wavenumber grid
log_gammaL_grid: logarithm of gammaL grid
Note:
When you have the error such as:
"failed to initialize batched cufft plan with customized allocator:
Allocating 8000000160 bytes exceeds the memory limit of 4294967296 bytes."
consider to use moditscanfft.calc_xsection_from_lsd, instead.
Returns:
Cross section in the log nu grid
"""

Sbuf = jnp.vstack([Slsd, jnp.zeros_like(Slsd)])
fftval = jnp.fft.rfft(Sbuf, axis=0)
Ng_nu = len(nu_grid)
# -----------------------------------------------
# MODIT w/o new folding
# til_Voigt=voigt_kernel_logst(k, log_nstbeta,log_ngammaL_grid)
# til_Slsd = jnp.fft.rfft(Sbuf,axis=0)
# fftvalsum = jnp.sum(til_Slsd*til_Voigt,axis=(1,))
# return jnp.fft.irfft(fftvalsum)[:Ng_nu]*R/nu_grid
# -----------------------------------------------
vk = fold_voigt_kernel_logst(
jnp.fft.rfftfreq(2 * Ng_nu, 1),
jnp.log(nsigmaD),
log_ngammaL_grid,
Ng_nu,
pmarray,
)
fftvalsum = jnp.sum(fftval * vk, axis=(1,))
return jnp.fft.irfft(fftvalsum)[:Ng_nu] * R / nu_grid


@jit
def xsvector(cnu, indexnu, R, pmarray, nsigmaD, ngammaL, S, nu_grid, ngammaL_grid):
"""Cross section vector (MODIT)
Notes:
Currently due to #277, we recommend to use
modit_scanfft.xsvector_scanfft instead of xsvector.
However, this will be changed when cufft fixes the 4GB limit.
Args:
cnu: contribution by npgetix for wavenumber
indexnu: index by npgetix for wavenumber
R: spectral resolution
pmarray: (+1,-1) array whose length of len(nu_grid)+1
nsigmaD: normaized Gaussian STD
gammaL: Lorentzian half width (Nlines)
S: line strength (Nlines)
nu_grid: linear wavenumber grid
gammaL_grid: gammaL grid
Returns:
Cross section in the log nu grid
"""

log_ngammaL_grid = jnp.log(ngammaL_grid)
lsd_array = jnp.zeros((len(nu_grid), len(ngammaL_grid)))
Slsd = inc2D_givenx(lsd_array, S, cnu, indexnu, jnp.log(ngammaL), log_ngammaL_grid)
xs = calc_xsection_from_lsd(Slsd, R, pmarray, nsigmaD, nu_grid, log_ngammaL_grid)
return xs


@jit
def xsmatrix(cnu, indexnu, resolution, pmarray, nsigmaDl, ngammaLM, SijM, nu_grid, dgm_ngammaL):
"""Cross section matrix for xsvector (MODIT)
Notes:
Currently due to #277, we recommend to use
modit_scanfft.xsmatrix_scanfft instead of xsmatrix.
However, this will be changed when cufft fixes the 4GB limit.
Args:
cnu: contribution by npgetix for wavenumber
indexnu: index by npgetix for wavenumber
resolution: spectral resolution, same as resolution (3rd return value) from utils.grids.wavenumber_grid
pmarray: (+1,-1) array whose length of len(nu_grid)+1
nu_lines: line center [Nlines]
nsigmaDl: normalized doppler sigma in layers in [Nlayer, 1]
ngammaLM: gamma factor matrix in [Nlayer, Nline]
SijM: line strength matrix in [Nlayer, Nline]
nu_grid: linear wavenumber grid
dgm_ngammaL: DIT Grid Matrix for normalized gammaL [Nlayer, NDITgrid]
Return:
cross section matrix in [Nlayer x Nwav]
"""
NDITgrid = jnp.shape(dgm_ngammaL)[1]
Nline = len(cnu)
Mat = jnp.hstack([nsigmaDl, ngammaLM, SijM, dgm_ngammaL])

def fxs(x, arr):
carry = 0.0
nsigmaD = arr[0:1]
ngammaL = arr[1 : Nline + 1]
Sij = arr[Nline + 1 : 2 * Nline + 1]
ngammaL_grid = arr[2 * Nline + 1 : 2 * Nline + NDITgrid + 1]
arr = xsvector(
cnu, indexnu, resolution, pmarray, nsigmaD, ngammaL, Sij, nu_grid, ngammaL_grid
)
return carry, arr

val, xsm = scan(fxs, 0.0, Mat)
return xsm


def exomol(mdb, Tarr, Parr, R, molmass):
"""compute molecular line information required for MODIT using Exomol mdb.
Expand Down Expand Up @@ -191,24 +70,24 @@ def set_ditgrid_matrix_exomol(mdb, fT, Parr, R, molmass, dit_grid_resolution, *k
"""Easy Setting of DIT Grid Matrix (dgm) using Exomol.
Args:
mdb: mdb instance
fT: function of temperature array
Parr: pressure array
R: spectral resolution
molmass: molecular mass
dit_grid_resolution: resolution of dgm
mdb: mdb instance
fT: function of temperature array
Parr: pressure array
R: spectral resolution
molmass: molecular mass
dit_grid_resolution: resolution of dgm
*kargs: arguments for fT
Returns:
DIT Grid Matrix (dgm) of normalized gammaL
DIT Grid Matrix (dgm) of normalized gammaL
Example:
>>> fT = lambda T0,alpha: T0[:,None]*(Parr[None,:]/Pref)**alpha[:,None]
>>> T0_test=np.array([1100.0,1500.0,1100.0,1500.0])
>>> alpha_test=np.array([0.2,0.2,0.05,0.05])
>>> dit_grid_resolution=0.2
>>> dgm_ngammaL=setdgm_exomol(mdbCH4,fT,Parr,R,molmassCH4,dit_grid_resolution,T0_test,alpha_test)
>>> fT = lambda T0,alpha: T0[:,None]*(Parr[None,:]/Pref)**alpha[:,None]
>>> T0_test=np.array([1100.0,1500.0,1100.0,1500.0])
>>> alpha_test=np.array([0.2,0.2,0.05,0.05])
>>> dit_grid_resolution=0.2
>>> dgm_ngammaL=setdgm_exomol(mdbCH4,fT,Parr,R,molmassCH4,dit_grid_resolution,T0_test,alpha_test)
"""
set_dgm_minmax = []
Tarr_list = fT(*kargs)
Expand All @@ -225,17 +104,17 @@ def hitran(mdb, Tarr, Parr, Pself, R, molmass):
"""compute molecular line information required for MODIT using HITRAN/HITEMP mdb.
Args:
mdb: mdb instance
Tarr: Temperature array
Parr: Pressure array
Pself: Partial pressure array
R: spectral resolution
molmass: molecular mass
mdb: mdb instance
Tarr: Temperature array
Parr: Pressure array
Pself: Partial pressure array
R: spectral resolution
molmass: molecular mass
Returns:
line intensity matrix,
normalized gammaL matrix,
normalized sigmaD matrix
line intensity matrix,
normalized gammaL matrix,
normalized sigmaD matrix
"""
qt = vmap(mdb.qr_interp_lines, (0, None))(Tarr, Tref_original)
SijM = jit(vmap(line_strength, (0, None, None, None, 0, None)))(
Expand Down Expand Up @@ -265,25 +144,25 @@ def set_ditgrid_matrix_hitran(
"""Easy Setting of DIT Grid Matrix (dgm) using HITRAN/HITEMP.
Args:
mdb: mdb instance
fT: function of temperature array
Parr: pressure array
Pself_ref: reference partial pressure array
R: spectral resolution
molmass: molecular mass
dit_grid_resolution: resolution of dgm
mdb: mdb instance
fT: function of temperature array
Parr: pressure array
Pself_ref: reference partial pressure array
R: spectral resolution
molmass: molecular mass
dit_grid_resolution: resolution of dgm
*kargs: arguments for fT
Returns:
DIT Grid Matrix (dgm) of normalized gammaL
DIT Grid Matrix (dgm) of normalized gammaL
Example:
>>> fT = lambda T0,alpha: T0[:,None]*(Parr[None,:]/Pref)**alpha[:,None]
>>> T0_test=np.array([1100.0,1500.0,1100.0,1500.0])
>>> alpha_test=np.array([0.2,0.2,0.05,0.05])
>>> dit_grid_resolution=0.2
>>> dgm_ngammaL=setdgm_hitran(mdbCH4,fT,Parr,Pself,R,molmassCH4,dit_grid_resolution,T0_test,alpha_test)
>>> fT = lambda T0,alpha: T0[:,None]*(Parr[None,:]/Pref)**alpha[:,None]
>>> T0_test=np.array([1100.0,1500.0,1100.0,1500.0])
>>> alpha_test=np.array([0.2,0.2,0.05,0.05])
>>> dit_grid_resolution=0.2
>>> dgm_ngammaL=setdgm_hitran(mdbCH4,fT,Parr,Pself,R,molmassCH4,dit_grid_resolution,T0_test,alpha_test)
"""
set_dgm_minmax = []
Tarr_list = fT(*kargs)
Expand Down
20 changes: 10 additions & 10 deletions src/exojax/spec/modit_scanfft.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,18 @@ def xsvector_scanfft(
"""Cross section vector (MODIT scanfft)
Args:
cnu: contribution by npgetix for wavenumber
indexnu: index by npgetix for wavenumber
R: spectral resolution
pmarray: (+1,-1) array whose length of len(nu_grid)+1
nsigmaD: normaized Gaussian STD
gammaL: Lorentzian half width (Nlines)
S: line strength (Nlines)
nu_grid: linear wavenumber grid
gammaL_grid: gammaL grid
cnu: contribution by npgetix for wavenumber
indexnu: index by npgetix for wavenumber
R: spectral resolution
pmarray: (+1,-1) array whose length of len(nu_grid)+1
nsigmaD: normaized Gaussian STD
gammaL: Lorentzian half width (Nlines)
S: line strength (Nlines)
nu_grid: linear wavenumber grid
gammaL_grid: gammaL grid
Returns:
Cross section in the log nu grid
Cross section in the log nu grid
"""

log_ngammaL_grid = jnp.log(ngammaL_grid)
Expand Down
6 changes: 3 additions & 3 deletions src/exojax/test/generate_xs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from exojax.test.data import TESTDATA_CO_HITEMP_MODIT_XS_REF
from exojax.test.data import TESTDATA_CO_HITEMP_MODIT_XS_REF_AIR
import numpy as np
from exojax.spec.modit import xsvector
from exojax.spec.modit_scanfft import xsvector_scanfft
from exojax.spec.hitran import line_strength
from exojax.spec.molinfo import molmass_isotope
from exojax.spec import normalized_doppler_sigma, gamma_natural
Expand Down Expand Up @@ -64,7 +64,7 @@ def gendata_xs_modit_exomol():
)

ngammaL_grid = ditgrid_log_interval(ngammaL, dit_grid_resolution=0.1)
xsv = xsvector(
xsv = xsvector_scanfft(
cont_nu, index_nu, R, pmarray, nsigmaD, ngammaL, Sij, nu_grid, ngammaL_grid
)

Expand Down Expand Up @@ -118,7 +118,7 @@ def gendata_xs_modit_hitemp(airmode=False):
cont_nu, index_nu, R, pmarray = init_modit(mdbCO.nu_lines, nu_grid)
ngammaL_grid = ditgrid_log_interval(ngammaL, dit_grid_resolution=0.1)

xsv = xsvector(
xsv = xsvector_scanfft(
cont_nu, index_nu, R, pmarray, nsigmaD, ngammaL, Sij, nu_grid, ngammaL_grid
)

Expand Down
4 changes: 2 additions & 2 deletions tests/benchmark/modit_bm_each.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
import time
from exojax.spec.modit import xsvector
from exojax.spec.modit_scanfft import xsvector_scanfft
from exojax.spec.set_ditgrid import ditgrid_log_interval
from exojax.spec import initspec
import jax.numpy as jnp
Expand All @@ -18,7 +18,7 @@ def xs(Nline):
ngammaL = gammaL/(nu_lines/R)
ngammaL_grid = ditgrid_log_interval(ngammaL, dit_grid_resolution=0.1)
S = jnp.array(np.random.normal(size=Nline))
xsv = xsvector(cnu, indexnu, R, pmarray, nsigmaD,
xsv = xsvector_scanfft(cnu, indexnu, R, pmarray, nsigmaD,
ngammaL, S, nus, ngammaL_grid)
xsv.block_until_ready()
return True
Expand Down
4 changes: 2 additions & 2 deletions tests/benchmark/modit_bm_wide.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
import time
from exojax.spec.modit import xsvector
from exojax.spec.modit_scanfft import xsvector_scanfft
from exojax.spec.set_ditgrid import ditgrid_log_interval
import numpy as np
import jax.numpy as jnp
Expand All @@ -24,7 +24,7 @@ def xs(Nc, Nline=10000):
a = []
for i in range(0, Nc):
tsx = time.time()
xsv = xsvector(cnu, indexnu, R, pmarray, nsigmaD,
xsv = xsvector_scanfft(cnu, indexnu, R, pmarray, nsigmaD,
ngammaL, S, nus, ngammaL_grid)
xsv.block_until_ready()
tex = time.time()
Expand Down
Loading

0 comments on commit ee26954

Please sign in to comment.