Skip to content

Commit

Permalink
Move source name into Source class.
Browse files Browse the repository at this point in the history
This will be useful for:
- removing passing both dynamic_slice and dynamic_source_slice into `get_value`.
- allowing a source to look up its own registered model funcs once the registry is complete.

PiperOrigin-RevId: 703467877
  • Loading branch information
Nush395 authored and Torax team committed Dec 6, 2024
1 parent d429409 commit 63569ba
Show file tree
Hide file tree
Showing 18 changed files with 175 additions and 156 deletions.
44 changes: 27 additions & 17 deletions torax/config/tests/runtime_params_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,20 +185,20 @@ def test_source_formula_config_has_time_dependent_params(self):
dcs = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
sources={
electron_density_sources.GAS_PUFF_SOURCE_NAME: (
electron_density_sources.GasPuffSource.SOURCE_NAME: (
electron_density_sources.GasPuffRuntimeParams(
puff_decay_length={0.0: 0.0, 1.0: 4.0},
S_puff_tot={0.0: 0.0, 1.0: 5.0},
)
),
electron_density_sources.PELLET_SOURCE_NAME: (
electron_density_sources.PelletSource.SOURCE_NAME: (
electron_density_sources.PelletRuntimeParams(
pellet_width={0.0: 0.0, 1.0: 1.0},
pellet_deposition_location={0.0: 0.0, 1.0: 2.0},
S_pellet_tot={0.0: 0.0, 1.0: 3.0},
)
),
electron_density_sources.GENERIC_PARTICLE_SOURCE_NAME: (
electron_density_sources.GenericParticleSource.SOURCE_NAME: (
electron_density_sources.GenericParticleSourceRuntimeParams(
particle_width={0.0: 0.0, 1.0: 6.0},
deposition_location={0.0: 0.0, 1.0: 7.0},
Expand All @@ -210,12 +210,14 @@ def test_source_formula_config_has_time_dependent_params(self):
)(
t=0.5,
)
pellet_source = dcs.sources[electron_density_sources.PELLET_SOURCE_NAME]
pellet_source = dcs.sources[
electron_density_sources.PelletSource.SOURCE_NAME
]
gas_puff_source = dcs.sources[
electron_density_sources.GAS_PUFF_SOURCE_NAME
electron_density_sources.GasPuffSource.SOURCE_NAME
]
generic_particle_source = dcs.sources[
electron_density_sources.GENERIC_PARTICLE_SOURCE_NAME
electron_density_sources.GenericParticleSource.SOURCE_NAME
]
assert isinstance(
pellet_source,
Expand Down Expand Up @@ -247,7 +249,7 @@ def test_source_formula_config_has_time_dependent_params(self):
dcs = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
sources={
electron_density_sources.GAS_PUFF_SOURCE_NAME: (
electron_density_sources.GasPuffSource.SOURCE_NAME: (
sources_params_lib.RuntimeParams(
formula=formula_config.Exponential(
total={0.0: 0.0, 1.0: 1.0},
Expand All @@ -262,7 +264,7 @@ def test_source_formula_config_has_time_dependent_params(self):
t=0.25,
)
gas_puff_source = dcs.sources[
electron_density_sources.GAS_PUFF_SOURCE_NAME
electron_density_sources.GasPuffSource.SOURCE_NAME
]
assert isinstance(
gas_puff_source.formula,
Expand All @@ -277,7 +279,7 @@ def test_source_formula_config_has_time_dependent_params(self):
dcs = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
sources={
electron_density_sources.GAS_PUFF_SOURCE_NAME: (
electron_density_sources.GasPuffSource.SOURCE_NAME: (
sources_params_lib.RuntimeParams(
formula=formula_config.Gaussian(
total={0.0: 0.0, 1.0: 1.0},
Expand All @@ -292,7 +294,7 @@ def test_source_formula_config_has_time_dependent_params(self):
t=0.25,
)
gas_puff_source = dcs.sources[
electron_density_sources.GAS_PUFF_SOURCE_NAME
electron_density_sources.GasPuffSource.SOURCE_NAME
]
assert isinstance(gas_puff_source.formula, formula_config.DynamicGaussian)
np.testing.assert_allclose(gas_puff_source.formula.total, 0.25)
Expand All @@ -306,7 +308,7 @@ def test_wext_in_dynamic_runtime_params_cannot_be_negative(self):
runtime_params=runtime_params,
transport=transport_params_lib.RuntimeParams(),
sources={
generic_current_source.SOURCE_NAME: (
generic_current_source.GenericCurrentSource.SOURCE_NAME: (
generic_current_source.RuntimeParams(wext={0.0: 1.0, 1.0: -1.0})
),
},
Expand All @@ -317,7 +319,9 @@ def test_wext_in_dynamic_runtime_params_cannot_be_negative(self):
dcs = dcs_provider(
t=0.0,
)
generic_current = dcs.sources[generic_current_source.SOURCE_NAME]
generic_current = dcs.sources[
generic_current_source.GenericCurrentSource.SOURCE_NAME
]
assert isinstance(
generic_current, generic_current_source.DynamicRuntimeParams
)
Expand All @@ -326,7 +330,9 @@ def test_wext_in_dynamic_runtime_params_cannot_be_negative(self):
dcs = dcs_provider(
t=0.5,
)
generic_current = dcs.sources[generic_current_source.SOURCE_NAME]
generic_current = dcs.sources[
generic_current_source.GenericCurrentSource.SOURCE_NAME
]
assert isinstance(
generic_current, generic_current_source.DynamicRuntimeParams
)
Expand Down Expand Up @@ -501,7 +507,7 @@ def test_update_dynamic_slice_provider_updates_sources(
runtime_params = general_runtime_params.GeneralRuntimeParams()
source_models_builder = default_sources.get_default_sources_builder()
source_models_builder.runtime_params[
generic_current_source.SOURCE_NAME
generic_current_source.GenericCurrentSource.SOURCE_NAME
].Iext = 1.0
geo = geometry.build_circular_geometry(n_rho=4)
provider = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider(
Expand All @@ -517,7 +523,7 @@ def test_update_dynamic_slice_provider_updates_sources(

# Update an interpolated variable.
source_models_builder.runtime_params[
generic_current_source.SOURCE_NAME
generic_current_source.GenericCurrentSource.SOURCE_NAME
].Iext = 2.0

# Check pre-update that nothing has changed.
Expand All @@ -526,7 +532,9 @@ def test_update_dynamic_slice_provider_updates_sources(
)
for key in source_models_builder.runtime_params.keys():
self.assertIn(key, dcs.sources)
generic_current = dcs.sources[generic_current_source.SOURCE_NAME]
generic_current = dcs.sources[
generic_current_source.GenericCurrentSource.SOURCE_NAME
]
assert isinstance(
generic_current, generic_current_source.DynamicRuntimeParams
)
Expand All @@ -543,7 +551,9 @@ def test_update_dynamic_slice_provider_updates_sources(
)
for key in source_models_builder.runtime_params.keys():
self.assertIn(key, dcs.sources)
generic_current = dcs.sources[generic_current_source.SOURCE_NAME]
generic_current = dcs.sources[
generic_current_source.GenericCurrentSource.SOURCE_NAME
]
assert isinstance(
generic_current, generic_current_source.DynamicRuntimeParams
)
Expand Down
4 changes: 2 additions & 2 deletions torax/core_profile_setters.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def _prescribe_currents_no_bootstrap(
dynamic_source_runtime_params=dynamic_generic_current_params,
static_runtime_params_slice=static_runtime_params_slice,
static_source_runtime_params=static_runtime_params_slice.sources[
generic_current_source.SOURCE_NAME
generic_current_source.GenericCurrentSource.SOURCE_NAME
],
geo=geo,
core_profiles=core_profiles,
Expand Down Expand Up @@ -1011,7 +1011,7 @@ def _get_jtot_hires(
dynamic_source_runtime_params=dynamic_generic_current_params,
static_runtime_params_slice=static_runtime_params_slice,
static_source_runtime_params=static_runtime_params_slice.sources[
generic_current_source.SOURCE_NAME
generic_current_source.GenericCurrentSource.SOURCE_NAME
],
geo=geo,
)
Expand Down
7 changes: 3 additions & 4 deletions torax/sources/bootstrap_current_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import dataclasses
from typing import ClassVar

import chex
import jax
Expand All @@ -34,9 +35,6 @@
from torax.sources import source_profiles


SOURCE_NAME = 'j_bootstrap'


@dataclasses.dataclass(kw_only=True)
class RuntimeParams(runtime_params_lib.RuntimeParams):
"""Configuration parameters for the bootstrap current source."""
Expand Down Expand Up @@ -91,6 +89,7 @@ class BootstrapCurrentSource(source.Source):
- bootstrap current (on cell and face grids)
- total integrated bootstrap current
"""
SOURCE_NAME: ClassVar[str] = 'j_bootstrap'

@property
def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]:
Expand Down Expand Up @@ -167,7 +166,7 @@ def get_source_profile_for_affected_core_profile(
) -> jax.Array:
return jnp.where(
affected_core_profile in self.affected_core_profiles_ints,
profile[SOURCE_NAME],
profile[self.SOURCE_NAME],
jnp.zeros_like(geo.rho),
)

Expand Down
5 changes: 2 additions & 3 deletions torax/sources/bremsstrahlung_heat_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""Bremsstrahlung heat sink for electron heat equation.."""

import dataclasses
from typing import ClassVar

import chex
import jax
Expand All @@ -31,9 +32,6 @@
from torax.sources import source_models


SOURCE_NAME = 'bremsstrahlung_heat_sink'


@dataclasses.dataclass(kw_only=True)
class RuntimeParams(runtime_params_lib.RuntimeParams):
use_relativistic_correction: bool = False
Expand Down Expand Up @@ -151,6 +149,7 @@ def bremsstrahlung_model_func(
@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class BremsstrahlungHeatSink(source.Source):
"""Brehmsstrahlung heat sink for electron heat equation."""
SOURCE_NAME: ClassVar[str] = 'bremsstrahlung_heat_sink'
model_func: source.SourceProfileFunction = bremsstrahlung_model_func

@property
Expand Down
5 changes: 2 additions & 3 deletions torax/sources/electron_cyclotron_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import dataclasses
from typing import ClassVar

import chex
import jax
Expand All @@ -36,8 +37,6 @@
runtime_params_lib.interpolated_param.InterpolatedVarTimeRhoInput
)

SOURCE_NAME = "electron_cyclotron_source"


@dataclasses.dataclass(kw_only=True)
class RuntimeParams(runtime_params_lib.RuntimeParams):
Expand Down Expand Up @@ -187,7 +186,7 @@ def _get_ec_output_shape(geo: geometry.Geometry) -> tuple[int, ...]:
@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class ElectronCyclotronSource(source.Source):
"""Electron cyclotron source for the Te and Psi equations."""

SOURCE_NAME: ClassVar[str] = "electron_cyclotron_source"
model_func: source.SourceProfileFunction = _calc_heating_and_current

@property
Expand Down
13 changes: 4 additions & 9 deletions torax/sources/electron_density_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import dataclasses
from typing import ClassVar

import chex
import jax
Expand Down Expand Up @@ -100,12 +101,10 @@ def _calc_puff_source(
)


GAS_PUFF_SOURCE_NAME = 'gas_puff_source'


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class GasPuffSource(source.Source):
"""Gas puff source for the ne equation."""
SOURCE_NAME: ClassVar[str] = 'gas_puff_source'
formula: source.SourceProfileFunction = _calc_puff_source

@property
Expand Down Expand Up @@ -194,12 +193,10 @@ def _calc_generic_particle_source(
)


GENERIC_PARTICLE_SOURCE_NAME = 'generic_particle_source'


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class GenericParticleSource(source.Source):
"""Neutral-beam injection source for the ne equation."""
SOURCE_NAME: ClassVar[str] = 'generic_particle_source'
formula: source.SourceProfileFunction = _calc_generic_particle_source

@property
Expand Down Expand Up @@ -277,12 +274,10 @@ def _calc_pellet_source(
)


PELLET_SOURCE_NAME = 'pellet_source'


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class PelletSource(source.Source):
"""Pellet source for the ne equation."""
SOURCE_NAME: ClassVar[str] = 'pellet_source'
formula: source.SourceProfileFunction = _calc_pellet_source

@property
Expand Down
6 changes: 2 additions & 4 deletions torax/sources/fusion_heat_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

import dataclasses
from typing import Optional
from typing import ClassVar, Optional

import jax
from jax import numpy as jnp
Expand All @@ -30,9 +30,6 @@
from torax.sources import source


SOURCE_NAME = 'fusion_heat_source'


def calc_fusion(
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
Expand Down Expand Up @@ -147,6 +144,7 @@ def fusion_heat_model_func(
@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class FusionHeatSource(source.Source):
"""Fusion heat source for both ion and electron heat."""
SOURCE_NAME: ClassVar[str] = 'fusion_heat_source'
model_func: source.SourceProfileFunction = fusion_heat_model_func

@property
Expand Down
6 changes: 2 additions & 4 deletions torax/sources/generic_current_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

import dataclasses
from typing import Optional
from typing import ClassVar, Optional

import chex
import jax
Expand All @@ -34,9 +34,6 @@
from torax.sources import runtime_params as runtime_params_lib
from torax.sources import source
from typing_extensions import override


SOURCE_NAME = 'generic_current_source'
# pylint: disable=invalid-name


Expand Down Expand Up @@ -236,6 +233,7 @@ def _calculate_Iext(
@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class GenericCurrentSource(source.Source):
"""A generic current density source profile."""
SOURCE_NAME: ClassVar[str] = 'generic_current_source'
formula: source.SourceProfileFunction = _calculate_generic_current_face
hires_formula: source.SourceProfileFunction = _calculate_generic_current_hires

Expand Down
6 changes: 2 additions & 4 deletions torax/sources/generic_ion_el_heat_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

import dataclasses
from typing import Optional
from typing import ClassVar, Optional

import chex
import jax
Expand All @@ -30,9 +30,6 @@
from torax.sources import formulas
from torax.sources import runtime_params as runtime_params_lib
from torax.sources import source


SOURCE_NAME = 'generic_ion_el_heat_source'
# Many variables throughout this function are capitalized based on physics
# notational conventions rather than on Google Python style
# pylint: disable=invalid-name
Expand Down Expand Up @@ -151,6 +148,7 @@ def _default_formula(
@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class GenericIonElectronHeatSource(source.Source):
"""Generic heat source for both ion and electron heat."""
SOURCE_NAME: ClassVar[str] = 'generic_ion_el_heat_source'
formula: source.SourceProfileFunction = _default_formula

@property
Expand Down
Loading

0 comments on commit 63569ba

Please sign in to comment.