Skip to content

Commit

Permalink
Merge pull request #568 from HajimeKawahara/opart_reflect_emis
Browse files Browse the repository at this point in the history
Opart reflect emis
  • Loading branch information
HajimeKawahara authored Jan 28, 2025
2 parents b01747b + 62ecdf8 commit 2f593f6
Show file tree
Hide file tree
Showing 4 changed files with 301 additions and 12 deletions.
163 changes: 156 additions & 7 deletions src/exojax/spec/opart.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class OpartReflectPure(ArtCommon):
This class computes the outgoing flux of the atmosphere with reflection, no emission from atmospheric layers nor surface.
Radiative transfer scheme: flux-based two-stream method, using flux-adding treatment, Toon-type hemispheric mean approximation
"""

def __init__(self, opalayer, pressure_top=1.0e-8, pressure_btm=1.0e2, nlayer=100):
Expand Down Expand Up @@ -146,10 +146,7 @@ def update_layer(self, carry_rs, params):
"""
Rphat_prev, Sphat_prev = carry_rs

# - no source term
# temparature = params[0]
# source_vector = piB(temparature, self.nu_grid)
# -------------------------------------------------
# no source term
source_vector = jnp.zeros_like(self.nu_grid)
dtau, single_scattering_albedo, asymmetric_parameter = self.opalayer(params)
trans_coeff_i, scat_coeff_i, pihatB_i, _, _, _ = setrt_toonhm(
Expand All @@ -167,7 +164,7 @@ def __call__(
self,
layer_params,
layer_update_function,
refectivity_bottom,
reflectivity_bottom,
incoming_flux,
):
"""computes outgoing flux
Expand All @@ -183,14 +180,166 @@ def __call__(
"""
# rs_bottom = (refectivity_bottom, source_bottom)
source_bottom = jnp.zeros_like(self.nu_grid)
rs_bottom = [refectivity_bottom, source_bottom]
rs_bottom = [reflectivity_bottom, source_bottom]
rs, _ = scan(layer_update_function, rs_bottom, layer_params)
return rs[0] * incoming_flux + rs[1]

def run(self, opalayer, layer_params, flbl):
return self(opalayer, layer_params, flbl)


class OpartReflectEmis(ArtCommon):
"""Opart verision of ArtReflectEmis.
This class computes the outgoing flux of the atmosphere with reflection, with emission from atmospheric layers.
Radiative transfer scheme: flux-based two-stream method, using flux-adding treatment, Toon-type hemispheric mean approximation
"""

def __init__(self, opalayer, pressure_top=1.0e-8, pressure_btm=1.0e2, nlayer=100):
"""Initialization of OpartReflectPure
Args:
opalayer (class): user defined class, needs to define self.nu_grid
pressure_top (float, optional): top pressure in bar. Defaults to 1.0e-8.
pressure_btm (float, optional): bottom pressure in bar. Defaults to 1.0e2.
nlayer (int, optional): the number of the atmospheric layers. Defaults to 100.
"""
super().__init__(pressure_top, pressure_btm, nlayer, opalayer.nu_grid)
self.opalayer = opalayer
self.nu_grid = self.opalayer.nu_grid

def update_layer(self, carry_rs, params):
"""updates the layer opacity and effective reflectivity (Rphat) and source (Sphat)
Args:
carry_rs (list): carry for the Rphat and Sphat
params (list): layer parameters for this layer
Returns:
list: updated carry_rs
"""
Rphat_prev, Sphat_prev = carry_rs

# blackbody source term in the layers
temparature = params[0]
source_vector = piB(temparature, self.nu_grid)
# -------------------------------------------------
dtau, single_scattering_albedo, asymmetric_parameter = self.opalayer(params)
trans_coeff_i, scat_coeff_i, pihatB_i, _, _, _ = setrt_toonhm(
dtau, single_scattering_albedo, asymmetric_parameter, source_vector
)
denom = 1.0 - scat_coeff_i * Rphat_prev
Sphat_each = (
pihatB_i + trans_coeff_i * (Sphat_prev + pihatB_i * Rphat_prev) / denom
)
Rphat_each = scat_coeff_i + trans_coeff_i**2 * Rphat_prev / denom
carry_rs = [Rphat_each, Sphat_each]
return carry_rs

def __call__(
self,
layer_params,
layer_update_function,
source_bottom,
reflectivity_bottom,
incoming_flux,
):
"""computes outgoing flux
Args:
layer_params (list): user defined layer parameters, layer_params[0] should be temperature array
layer_update_function (method):
source_bottom (array): source at the bottom [Nnus]
reflectivity_bottom (array): reflectivity at the bottom [Nnus]
incoming_flux (array): incoming flux [Nnus]
Returns:
array: flux [Nnus]
"""
rs_bottom = [reflectivity_bottom, source_bottom]
rs, _ = scan(layer_update_function, rs_bottom, layer_params)
return rs[0] * incoming_flux + rs[1]

def run(self, opalayer, layer_params, flbl):
return self(opalayer, layer_params, flbl)


class OpartEmisScat(ArtCommon):
"""Opart verision of ArtEmisScat.
This class computes the outgoing emission flux of the atmosphere with scattering in the atmospheric layers.
Radiative transfer scheme: flux-based two-stream method, using flux-adding treatment, Toon-type hemispheric mean approximation
"""

def __init__(self, opalayer, pressure_top=1.0e-8, pressure_btm=1.0e2, nlayer=100):
"""Initialization of OpartReflectPure
Args:
opalayer (class): user defined class, needs to define self.nu_grid
pressure_top (float, optional): top pressure in bar. Defaults to 1.0e-8.
pressure_btm (float, optional): bottom pressure in bar. Defaults to 1.0e2.
nlayer (int, optional): the number of the atmospheric layers. Defaults to 100.
"""
super().__init__(pressure_top, pressure_btm, nlayer, opalayer.nu_grid)
self.opalayer = opalayer
self.nu_grid = self.opalayer.nu_grid

def update_layer(self, carry_rs, params):
"""updates the layer opacity and effective reflectivity (Rphat) and source (Sphat)
Args:
carry_rs (list): carry for the Rphat and Sphat
params (list): layer parameters for this layer
Returns:
list: updated carry_rs
"""
Rphat_prev, Sphat_prev = carry_rs

# blackbody source term in the layers
temparature = params[0]
source_vector = piB(temparature, self.nu_grid)
# -------------------------------------------------
dtau, single_scattering_albedo, asymmetric_parameter = self.opalayer(params)
trans_coeff_i, scat_coeff_i, pihatB_i, _, _, _ = setrt_toonhm(
dtau, single_scattering_albedo, asymmetric_parameter, source_vector
)
denom = 1.0 - scat_coeff_i * Rphat_prev
Sphat_each = (
pihatB_i + trans_coeff_i * (Sphat_prev + pihatB_i * Rphat_prev) / denom
)
Rphat_each = scat_coeff_i + trans_coeff_i**2 * Rphat_prev / denom
carry_rs = [Rphat_each, Sphat_each]
return carry_rs

def __call__(
self,
layer_params,
layer_update_function,
):
"""computes outgoing flux
Args:
layer_params (list): user defined layer parameters, layer_params[0] should be temperature array
layer_update_function (method):
Returns:
array: flux [Nnus]
"""
# no reflection at the bottom
reflectivity_bottom = jnp.zeros_like(self.nu_grid)
# no source term at the bottom
source_bottom = jnp.zeros_like(self.nu_grid)
rs_bottom = [reflectivity_bottom, source_bottom]
rs, _ = scan(layer_update_function, rs_bottom, layer_params)
return rs[1]

def run(self, opalayer, layer_params, flbl):
return self(opalayer, layer_params, flbl)


if __name__ == "__main__":

from exojax.spec.opacalc import OpaPremodit
Expand Down
65 changes: 65 additions & 0 deletions tests/integration/unittests_long/opart/opart_emis_scat_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""checks the forward model of the opart spectrum
"""

import pytest
import numpy as np
import jax.numpy as jnp
from exojax.test.emulate_mdb import mock_wavenumber_grid
from exojax.test.emulate_mdb import mock_mdbExomol
from exojax.spec.opacalc import OpaPremodit
from exojax.spec.opart import OpartEmisScat
from exojax.spec.layeropacity import single_layer_optical_depth

from jax import config

config.update("jax_enable_x64", True)


def test_forward_emis_scat_opart():
class OpaLayer:
# user defined class, needs to define self.nugrid
def __init__(self):
self.nu_grid, self.wav, self.resolution = mock_wavenumber_grid()
self.gravity = 2478.57
self.mdb_co = mock_mdbExomol()

self.opa_co = OpaPremodit(
self.mdb_co, self.nu_grid, auto_trange=[400.0, 1500.0]
)

def __call__(self, params):
temperature, pressure, dP, mixing_ratio = params
xsv_co = self.opa_co.xsvector(temperature, pressure)
dtau_co = single_layer_optical_depth(
dP, xsv_co, mixing_ratio, self.mdb_co.molmass, self.gravity
)
single_scattering_albedo = jnp.ones_like(dtau_co) * 0.3
asymmetric_parameter = jnp.ones_like(dtau_co) * 0.01

return dtau_co, single_scattering_albedo, asymmetric_parameter

opalayer = OpaLayer()
opart = OpartEmisScat(opalayer, pressure_top=1.0e-6, pressure_btm=1.0e0, nlayer=200)

def layer_update_function(carry, params):
carry = opart.update_layer(carry, params)
return carry, None

temperature = opart.powerlaw_temperature(1300.0, 0.1)
mixing_ratio = opart.constant_mmr_profile(0.0003)
layer_params = [temperature, opart.pressure, opart.dParr, mixing_ratio]
flux = opart(layer_params, layer_update_function)
print(np.mean(flux))
ref = 515245.12625256577 # 1/28 2025
assert np.mean(flux) == pytest.approx(ref)
plot = False
if plot:
import matplotlib.pyplot as plt

fig = plt.figure()
plt.plot(opalayer.nu_grid, flux)
plt.savefig("forward_opart_emis_scat.png")


if __name__ == "__main__":
test_forward_emis_scat_opart()
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""checks the forward model of the opart spectrum
"""
import pytest
import numpy as np
import jax.numpy as jnp
from exojax.test.data import TESTDATA_CO_EXOMOL_PREMODIT_REFLECTION_REF
from exojax.test.emulate_mdb import mock_wavenumber_grid
from exojax.test.emulate_mdb import mock_mdbExomol
from exojax.spec.opacalc import OpaPremodit
from exojax.spec.opart import OpartReflectEmis
from exojax.spec.layeropacity import single_layer_optical_depth

from jax import config
config.update("jax_enable_x64", True)

def test_forward_reflection_emis_opart():
class OpaLayer:
# user defined class, needs to define self.nugrid
def __init__(self):
self.nu_grid, self.wav, self.resolution = mock_wavenumber_grid()
self.gravity = 2478.57
self.mdb_co = mock_mdbExomol()

self.opa_co = OpaPremodit(
self.mdb_co, self.nu_grid, auto_trange=[400.0, 1500.0]
)

def __call__(self, params):
temperature, pressure, dP, mixing_ratio = params
xsv_co = self.opa_co.xsvector(temperature, pressure)
dtau_co = single_layer_optical_depth(
dP, xsv_co, mixing_ratio, self.mdb_co.molmass, self.gravity
)
single_scattering_albedo = jnp.ones_like(dtau_co) * 0.3
asymmetric_parameter = jnp.ones_like(dtau_co) * 0.01

return dtau_co, single_scattering_albedo, asymmetric_parameter

opalayer = OpaLayer()
opart = OpartReflectEmis(
opalayer, pressure_top=1.0e-6, pressure_btm=1.0e0, nlayer=200
)

def layer_update_function(carry, params):
carry = opart.update_layer(carry, params)
return carry, None

temperature = opart.powerlaw_temperature(1300.0, 0.1)
mixing_ratio = opart.constant_mmr_profile(0.0003)
layer_params = [temperature, opart.pressure, opart.dParr, mixing_ratio]

albedo = 1.0
constant_incoming_flux = 1.e6
incoming_flux = constant_incoming_flux*jnp.ones_like(opalayer.nu_grid)
reflectivity_surface = albedo * jnp.ones_like(opalayer.nu_grid)

constant_surface_flux = 1.e5
source_bottom = constant_surface_flux*jnp.ones_like(opalayer.nu_grid)

flux = opart(
layer_params, layer_update_function, source_bottom, reflectivity_surface, incoming_flux
)
ref = 1764352.3124300546 #2021 1/28
print(np.mean(flux))
assert np.mean(flux) == pytest.approx(ref)
plot = False
if plot:
import matplotlib.pyplot as plt
fig = plt.figure()
ax = fig.add_subplot(111)
plt.plot(opalayer.nu_grid, flux)
plt.savefig("forward_opart_reflect_emis.png")


if __name__ == "__main__":
test_forward_reflection_emis_opart()
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""checks the forward model of the opart spectrum
"""
import pkg_resources
from importlib.resources import files
import pandas as pd
import numpy as np
import jax.numpy as jnp
Expand Down Expand Up @@ -53,14 +53,13 @@ def layer_update_function(carry, params):
albedo = 1.0
incoming_flux = jnp.ones_like(opalayer.nu_grid)
reflectivity_surface = albedo * jnp.ones_like(opalayer.nu_grid)
source_bottom = jnp.zeros_like(opalayer.nu_grid)


flux = opart(
layer_params, layer_update_function, reflectivity_surface, incoming_flux
)

filename = pkg_resources.resource_filename('exojax',
'data/testdata/' + TESTDATA_CO_EXOMOL_PREMODIT_REFLECTION_REF)
filename = files('exojax').joinpath('data/testdata/' + TESTDATA_CO_EXOMOL_PREMODIT_REFLECTION_REF)

dat = pd.read_csv(filename, delimiter=",", names=("nus", "flux"))

residual = np.abs(flux / dat["flux"].values - 1.0)
Expand Down

0 comments on commit 2f593f6

Please sign in to comment.