forked from HajimeKawahara/exojax
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request HajimeKawahara#457 from ykawashima/comp_petit
Scripts for the comparison with petitRADTRANS
- Loading branch information
Showing
3 changed files
with
654 additions
and
0 deletions.
There are no files selected for viewing
218 changes: 218 additions & 0 deletions
218
tests/integration/comparison/twostream/comparison_petitRADTRANS_CIA.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,218 @@ | ||
from jax import config | ||
config.update("jax_enable_x64", True) | ||
|
||
import numpy as np | ||
import jax.numpy as jnp | ||
import math | ||
import os | ||
|
||
from exojax.atm.atmprof import pressure_layer_logspace | ||
from exojax.utils.grids import wavenumber_grid | ||
from exojax.spec.multimol import MultiMol | ||
from exojax.spec import contdb | ||
from exojax.spec.layeropacity import layer_optical_depth, layer_optical_depth_CIA | ||
from exojax.spec import planck | ||
from exojax.spec.rtransfer import rtrun_emis_pure_absorption | ||
|
||
from petitRADTRANS import Radtrans | ||
from exojax.spec import molinfo | ||
import petitRADTRANS.nat_cst as nc | ||
|
||
from exojax.utils.instfunc import resolution_to_gaussian_std | ||
from exojax.utils.grids import velocity_grid | ||
from exojax.spec import response | ||
|
||
|
||
|
||
def run_exojax(path_data, ld_min, ld_max, mols, db, T0, alpha, logg, logvmr): | ||
Parr, dParr, k = pressure_layer_logspace(log_pressure_top=-3., nlayer=200) | ||
ONEARR = np.ones_like(Parr) | ||
|
||
|
||
R = 900000. | ||
ld_min = ld_min - 5. | ||
ld_max = ld_max + 5. | ||
nu_min = 1.0e8 / ld_max | ||
nu_max = 1.0e8 / ld_min | ||
Nx = math.ceil(R * np.log(nu_max / nu_min)) + 1 # ueki | ||
Nx = math.ceil(Nx / 2.) * 2 # make even | ||
|
||
nus, wav, res = wavenumber_grid(ld_min, ld_max, Nx, unit="AA", xsmode="premodit") | ||
nus = [nus] | ||
wav = [wav] | ||
res = [res] | ||
|
||
|
||
mul = MultiMol(molmulti=[mols], dbmulti=[db], database_root_path=path_data) | ||
multimdb = mul.multimdb(nus, crit=1.e-30, Ttyp=1000.) | ||
multiopa = mul.multiopa_premodit(multimdb, nus, auto_trange=[500.,1500.], dit_grid_resolution=1.0) | ||
|
||
cdbH2H2 = [] | ||
cdbH2He = [] | ||
for k in range(len(nus)): | ||
cdbH2H2.append(contdb.CdbCIA(os.path.join(path_data, 'H2-H2_2011.cia'), nus[k])) | ||
cdbH2He.append(contdb.CdbCIA(os.path.join(path_data, 'H2-He_2011.cia'), nus[k])) | ||
|
||
molmass_list, molmassH2, molmassHe = mul.molmass() | ||
|
||
|
||
def frun(T0, alpha, logg, logvmr): | ||
Tarr = T0 * (Parr)**alpha | ||
Tarr = np.clip(Tarr, 500, None) | ||
|
||
g = 10.**logg # cgs | ||
|
||
vmr = jnp.power(10., jnp.array(logvmr)) | ||
vmrH2 = (1. - jnp.sum(vmr)) * 6./7. | ||
vmrHe = (1. - jnp.sum(vmr)) * 1./7. | ||
mmw = jnp.sum(vmr*jnp.array(molmass_list)) + vmrH2*molmassH2 + vmrHe*molmassHe | ||
mmr = jnp.multiply(vmr, jnp.array(molmass_list)) / mmw | ||
|
||
mu = [] | ||
for k in range(len(nus)): | ||
dtaum = [] | ||
for i in range(len(mul.masked_molmulti[k])): | ||
xsm = multiopa[k][i].xsmatrix(Tarr, Parr) | ||
xsm = jnp.abs(xsm) | ||
dtaum.append(layer_optical_depth(dParr, xsm, mmr[mul.mols_num[k][i]]*ONEARR, molmass_list[mul.mols_num[k][i]], g)) | ||
|
||
dtau = sum(dtaum) | ||
|
||
if(len(cdbH2H2[k].nucia) > 0): | ||
dtaucH2H2 = layer_optical_depth_CIA(nus[k], Tarr, Parr, dParr, vmrH2, vmrH2, mmw, g, cdbH2H2[k].nucia, cdbH2H2[k].tcia, cdbH2H2[k].logac) | ||
dtau = dtau + dtaucH2H2 | ||
if(len(cdbH2He[k].nucia) > 0): | ||
dtaucH2He = layer_optical_depth_CIA(nus[k], Tarr, Parr, dParr, vmrH2, vmrHe, mmw, g, cdbH2He[k].nucia, cdbH2He[k].tcia, cdbH2He[k].logac) | ||
dtau = dtau + dtaucH2He | ||
|
||
sourcef = planck.piBarr(Tarr, nus[k]) | ||
F0 = rtrun_emis_pure_absorption(dtau, sourcef) | ||
|
||
mu.append(F0) | ||
|
||
return mu | ||
|
||
return nus[0], frun(T0, alpha, logg, logvmr)[0] | ||
|
||
|
||
|
||
def run_petit(ld_min, ld_max, mols, mols_exojax, T0, alpha, logg, logvmr): | ||
atmosphere = Radtrans(line_species = mols, | ||
rayleigh_species = ['H2', 'He'], | ||
continuum_opacities = ['H2-H2', 'H2-He'], | ||
wlen_bords_micron = [ld_min*1e-4, ld_max*1e-4], | ||
mode = 'lbl') | ||
|
||
pressures = np.logspace(-10, 2, 130) | ||
atmosphere.setup_opa_structure(pressures) | ||
temperature = T0*(pressures)**alpha | ||
temperature = np.clip(temperature, 500, None) | ||
|
||
|
||
molmass_list = [] | ||
for i in range(len(mols_exojax)): | ||
molmass_list.append(molinfo.molmass(mols_exojax[i])) | ||
molmassH2=molinfo.molmass("H2") | ||
molmassHe=molinfo.molmass("He", db_HIT=False) | ||
|
||
vmr = jnp.power(10., jnp.array(logvmr)) | ||
vmrH2 = (1. - jnp.sum(vmr)) * 6./7. | ||
vmrHe = (1. - jnp.sum(vmr)) * 1./7. | ||
mmw = jnp.sum(vmr*jnp.array(molmass_list)) + vmrH2*molmassH2 + vmrHe*molmassHe | ||
mmr = jnp.multiply(vmr, jnp.array(molmass_list)) / mmw | ||
mmrH2 = vmrH2 * molmassH2 / mmw | ||
mmrHe = vmrHe * molmassHe / mmw | ||
|
||
mass_fractions = {} | ||
mass_fractions['H2'] = mmrH2 * np.ones_like(temperature) | ||
mass_fractions['He'] = mmrHe * np.ones_like(temperature) | ||
for i in range(len(mols)): | ||
mass_fractions[mols[i]] = mmr[i] * np.ones_like(temperature) | ||
|
||
|
||
MMW = mmw * np.ones_like(temperature) | ||
gravity = 1e1**(logg) | ||
|
||
atmosphere.calc_flux(temperature, mass_fractions, gravity, MMW) | ||
|
||
ld = nc.c/atmosphere.freq/1e-4 # [um] | ||
nus = 1.0e4 / ld # [cm^{-1}] | ||
f = atmosphere.flux #[erg cm^{-2} s^{-1} Hz^{-1}] | ||
f = f * nc.c #[erg/s/cm^2/cm^{-1}] | ||
|
||
nus = nus[::-1] | ||
f = f[::-1] | ||
|
||
return nus, f | ||
|
||
|
||
|
||
path_data = "/home/kawashima/database" | ||
ld_min = 23000. | ||
ld_max = 24000. | ||
mols_exojax = ['CO'] | ||
db_exojax = ['ExoMol'] | ||
mols_petit = ['CO_all_iso'] | ||
|
||
T0 = 995.56 | ||
alpha = 0.09 | ||
logg = 5.01 | ||
logvmr = [-3.06] | ||
|
||
nus1, f1 = run_exojax(path_data, ld_min, ld_max, mols_exojax, db_exojax, T0, alpha, logg, logvmr) | ||
nus2, f2 = run_petit(ld_min, ld_max, mols_petit, mols_exojax, T0, alpha, logg, logvmr) | ||
|
||
|
||
|
||
Rinst = 70000. | ||
nu_min = 1.0e8 / ld_max | ||
nu_max = 1.0e8 / ld_min | ||
Nx = math.ceil(Rinst * np.log(nu_max / nu_min)) + 1 # ueki | ||
Nx = math.ceil(Nx / 2.) * 2 # make even | ||
nusd, wav, res = wavenumber_grid(ld_min, ld_max, Nx, unit="AA", xsmode="premodit") | ||
|
||
beta_inst = resolution_to_gaussian_std(Rinst) | ||
res_calc = 900000. | ||
vsini_max = 100.0 | ||
vr_array = velocity_grid(res_calc, vsini_max) | ||
|
||
f1_obs = response.ipgauss_sampling(nusd, nus1, f1, beta_inst, 0., vr_array) | ||
if len(nus2) % 2 == 1: | ||
nus2 = nus2[1:] | ||
f2 = f2[1:] | ||
f2_obs = response.ipgauss_sampling(nusd, nus2, f2, beta_inst, 0., vr_array) | ||
|
||
|
||
|
||
import matplotlib.pyplot as plt | ||
from matplotlib.ticker import (MultipleLocator, FormatStrFormatter, AutoMinorLocator) | ||
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(40,8), gridspec_kw={'height_ratios': [3, 1]}) | ||
plt.subplots_adjust(hspace=0) | ||
ax1.plot(1.0e8/nusd, f1_obs, alpha=0.5, label="ExoJAX") | ||
ax1.plot(1.0e8/nusd, f2_obs, alpha=0.5, label="petitRADTRANS") | ||
|
||
ax2.plot(1.0e8/nusd, f1_obs - f2_obs, "+", color="black", markersize=5, alpha=0.2) | ||
|
||
ax1.set_ylabel("Flux [erg/s/cm$\mathrm{^2}$/cm$\mathrm{^{-1}}$]", fontsize=15) | ||
ax2.set_xlabel("Wavelength [$\AA$]", fontsize=15) | ||
ax2.set_ylabel("Residual", fontsize=15) | ||
|
||
ax1.set_xlim(np.min(1.0e8/nusd), np.max(1.0e8/nusd)) | ||
ax2.set_xlim(np.min(1.0e8/nusd), np.max(1.0e8/nusd)) | ||
|
||
ax1.xaxis.set_ticks_position('both') | ||
ax1.yaxis.set_ticks_position('both') | ||
ax2.xaxis.set_ticks_position('both') | ||
ax2.yaxis.set_ticks_position('both') | ||
|
||
ax1.xaxis.set_minor_locator(AutoMinorLocator()) | ||
ax1.yaxis.set_minor_locator(AutoMinorLocator()) | ||
ax2.xaxis.set_minor_locator(AutoMinorLocator()) | ||
ax2.yaxis.set_minor_locator(AutoMinorLocator()) | ||
|
||
ax2.patch.set_alpha(0) | ||
ax1.tick_params(labelbottom=False, labeltop=True) | ||
|
||
ax1.legend() | ||
#plt.show() | ||
plt.savefig("output/CIA_R70000.png", bbox_inches='tight') |
Oops, something went wrong.