Skip to content

Commit

Permalink
Merge pull request HajimeKawahara#457 from ykawashima/comp_petit
Browse files Browse the repository at this point in the history
Scripts for the comparison with petitRADTRANS
  • Loading branch information
HajimeKawahara authored Jan 11, 2024
2 parents c184a32 + b3e0e3c commit 7468045
Show file tree
Hide file tree
Showing 3 changed files with 654 additions and 0 deletions.
218 changes: 218 additions & 0 deletions tests/integration/comparison/twostream/comparison_petitRADTRANS_CIA.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
from jax import config
config.update("jax_enable_x64", True)

import numpy as np
import jax.numpy as jnp
import math
import os

from exojax.atm.atmprof import pressure_layer_logspace
from exojax.utils.grids import wavenumber_grid
from exojax.spec.multimol import MultiMol
from exojax.spec import contdb
from exojax.spec.layeropacity import layer_optical_depth, layer_optical_depth_CIA
from exojax.spec import planck
from exojax.spec.rtransfer import rtrun_emis_pure_absorption

from petitRADTRANS import Radtrans
from exojax.spec import molinfo
import petitRADTRANS.nat_cst as nc

from exojax.utils.instfunc import resolution_to_gaussian_std
from exojax.utils.grids import velocity_grid
from exojax.spec import response



def run_exojax(path_data, ld_min, ld_max, mols, db, T0, alpha, logg, logvmr):
Parr, dParr, k = pressure_layer_logspace(log_pressure_top=-3., nlayer=200)
ONEARR = np.ones_like(Parr)


R = 900000.
ld_min = ld_min - 5.
ld_max = ld_max + 5.
nu_min = 1.0e8 / ld_max
nu_max = 1.0e8 / ld_min
Nx = math.ceil(R * np.log(nu_max / nu_min)) + 1 # ueki
Nx = math.ceil(Nx / 2.) * 2 # make even

nus, wav, res = wavenumber_grid(ld_min, ld_max, Nx, unit="AA", xsmode="premodit")
nus = [nus]
wav = [wav]
res = [res]


mul = MultiMol(molmulti=[mols], dbmulti=[db], database_root_path=path_data)
multimdb = mul.multimdb(nus, crit=1.e-30, Ttyp=1000.)
multiopa = mul.multiopa_premodit(multimdb, nus, auto_trange=[500.,1500.], dit_grid_resolution=1.0)

cdbH2H2 = []
cdbH2He = []
for k in range(len(nus)):
cdbH2H2.append(contdb.CdbCIA(os.path.join(path_data, 'H2-H2_2011.cia'), nus[k]))
cdbH2He.append(contdb.CdbCIA(os.path.join(path_data, 'H2-He_2011.cia'), nus[k]))

molmass_list, molmassH2, molmassHe = mul.molmass()


def frun(T0, alpha, logg, logvmr):
Tarr = T0 * (Parr)**alpha
Tarr = np.clip(Tarr, 500, None)

g = 10.**logg # cgs

vmr = jnp.power(10., jnp.array(logvmr))
vmrH2 = (1. - jnp.sum(vmr)) * 6./7.
vmrHe = (1. - jnp.sum(vmr)) * 1./7.
mmw = jnp.sum(vmr*jnp.array(molmass_list)) + vmrH2*molmassH2 + vmrHe*molmassHe
mmr = jnp.multiply(vmr, jnp.array(molmass_list)) / mmw

mu = []
for k in range(len(nus)):
dtaum = []
for i in range(len(mul.masked_molmulti[k])):
xsm = multiopa[k][i].xsmatrix(Tarr, Parr)
xsm = jnp.abs(xsm)
dtaum.append(layer_optical_depth(dParr, xsm, mmr[mul.mols_num[k][i]]*ONEARR, molmass_list[mul.mols_num[k][i]], g))

dtau = sum(dtaum)

if(len(cdbH2H2[k].nucia) > 0):
dtaucH2H2 = layer_optical_depth_CIA(nus[k], Tarr, Parr, dParr, vmrH2, vmrH2, mmw, g, cdbH2H2[k].nucia, cdbH2H2[k].tcia, cdbH2H2[k].logac)
dtau = dtau + dtaucH2H2
if(len(cdbH2He[k].nucia) > 0):
dtaucH2He = layer_optical_depth_CIA(nus[k], Tarr, Parr, dParr, vmrH2, vmrHe, mmw, g, cdbH2He[k].nucia, cdbH2He[k].tcia, cdbH2He[k].logac)
dtau = dtau + dtaucH2He

sourcef = planck.piBarr(Tarr, nus[k])
F0 = rtrun_emis_pure_absorption(dtau, sourcef)

mu.append(F0)

return mu

return nus[0], frun(T0, alpha, logg, logvmr)[0]



def run_petit(ld_min, ld_max, mols, mols_exojax, T0, alpha, logg, logvmr):
atmosphere = Radtrans(line_species = mols,
rayleigh_species = ['H2', 'He'],
continuum_opacities = ['H2-H2', 'H2-He'],
wlen_bords_micron = [ld_min*1e-4, ld_max*1e-4],
mode = 'lbl')

pressures = np.logspace(-10, 2, 130)
atmosphere.setup_opa_structure(pressures)
temperature = T0*(pressures)**alpha
temperature = np.clip(temperature, 500, None)


molmass_list = []
for i in range(len(mols_exojax)):
molmass_list.append(molinfo.molmass(mols_exojax[i]))
molmassH2=molinfo.molmass("H2")
molmassHe=molinfo.molmass("He", db_HIT=False)

vmr = jnp.power(10., jnp.array(logvmr))
vmrH2 = (1. - jnp.sum(vmr)) * 6./7.
vmrHe = (1. - jnp.sum(vmr)) * 1./7.
mmw = jnp.sum(vmr*jnp.array(molmass_list)) + vmrH2*molmassH2 + vmrHe*molmassHe
mmr = jnp.multiply(vmr, jnp.array(molmass_list)) / mmw
mmrH2 = vmrH2 * molmassH2 / mmw
mmrHe = vmrHe * molmassHe / mmw

mass_fractions = {}
mass_fractions['H2'] = mmrH2 * np.ones_like(temperature)
mass_fractions['He'] = mmrHe * np.ones_like(temperature)
for i in range(len(mols)):
mass_fractions[mols[i]] = mmr[i] * np.ones_like(temperature)


MMW = mmw * np.ones_like(temperature)
gravity = 1e1**(logg)

atmosphere.calc_flux(temperature, mass_fractions, gravity, MMW)

ld = nc.c/atmosphere.freq/1e-4 # [um]
nus = 1.0e4 / ld # [cm^{-1}]
f = atmosphere.flux #[erg cm^{-2} s^{-1} Hz^{-1}]
f = f * nc.c #[erg/s/cm^2/cm^{-1}]

nus = nus[::-1]
f = f[::-1]

return nus, f



path_data = "/home/kawashima/database"
ld_min = 23000.
ld_max = 24000.
mols_exojax = ['CO']
db_exojax = ['ExoMol']
mols_petit = ['CO_all_iso']

T0 = 995.56
alpha = 0.09
logg = 5.01
logvmr = [-3.06]

nus1, f1 = run_exojax(path_data, ld_min, ld_max, mols_exojax, db_exojax, T0, alpha, logg, logvmr)
nus2, f2 = run_petit(ld_min, ld_max, mols_petit, mols_exojax, T0, alpha, logg, logvmr)



Rinst = 70000.
nu_min = 1.0e8 / ld_max
nu_max = 1.0e8 / ld_min
Nx = math.ceil(Rinst * np.log(nu_max / nu_min)) + 1 # ueki
Nx = math.ceil(Nx / 2.) * 2 # make even
nusd, wav, res = wavenumber_grid(ld_min, ld_max, Nx, unit="AA", xsmode="premodit")

beta_inst = resolution_to_gaussian_std(Rinst)
res_calc = 900000.
vsini_max = 100.0
vr_array = velocity_grid(res_calc, vsini_max)

f1_obs = response.ipgauss_sampling(nusd, nus1, f1, beta_inst, 0., vr_array)
if len(nus2) % 2 == 1:
nus2 = nus2[1:]
f2 = f2[1:]
f2_obs = response.ipgauss_sampling(nusd, nus2, f2, beta_inst, 0., vr_array)



import matplotlib.pyplot as plt
from matplotlib.ticker import (MultipleLocator, FormatStrFormatter, AutoMinorLocator)
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(40,8), gridspec_kw={'height_ratios': [3, 1]})
plt.subplots_adjust(hspace=0)
ax1.plot(1.0e8/nusd, f1_obs, alpha=0.5, label="ExoJAX")
ax1.plot(1.0e8/nusd, f2_obs, alpha=0.5, label="petitRADTRANS")

ax2.plot(1.0e8/nusd, f1_obs - f2_obs, "+", color="black", markersize=5, alpha=0.2)

ax1.set_ylabel("Flux [erg/s/cm$\mathrm{^2}$/cm$\mathrm{^{-1}}$]", fontsize=15)
ax2.set_xlabel("Wavelength [$\AA$]", fontsize=15)
ax2.set_ylabel("Residual", fontsize=15)

ax1.set_xlim(np.min(1.0e8/nusd), np.max(1.0e8/nusd))
ax2.set_xlim(np.min(1.0e8/nusd), np.max(1.0e8/nusd))

ax1.xaxis.set_ticks_position('both')
ax1.yaxis.set_ticks_position('both')
ax2.xaxis.set_ticks_position('both')
ax2.yaxis.set_ticks_position('both')

ax1.xaxis.set_minor_locator(AutoMinorLocator())
ax1.yaxis.set_minor_locator(AutoMinorLocator())
ax2.xaxis.set_minor_locator(AutoMinorLocator())
ax2.yaxis.set_minor_locator(AutoMinorLocator())

ax2.patch.set_alpha(0)
ax1.tick_params(labelbottom=False, labeltop=True)

ax1.legend()
#plt.show()
plt.savefig("output/CIA_R70000.png", bbox_inches='tight')
Loading

0 comments on commit 7468045

Please sign in to comment.