Skip to content

Commit

Permalink
Use Vloop_bound_right in compute_boundary_conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
theo-brown committed Nov 26, 2024
1 parent 90098a5 commit 93fd0aa
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions torax/core_profile_setters.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,7 @@ def _calculate_psi_grad_constraint_from_Ip_tot(
/ (geo.g2g3_over_rhon_face[-1] * geo.F_face[-1])
)


def _psi_value_constraint_from_Vloop(
dt: jax.Array,
dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice,
Expand All @@ -576,6 +577,7 @@ def _psi_value_constraint_from_Vloop(
+ dynamic_runtime_params_slice_t.profile_conditions.Vloop_bound_right * dt
)


def _init_psi_and_current(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: Geometry,
Expand Down Expand Up @@ -965,10 +967,15 @@ def compute_boundary_conditions(
right_face_constraint=jnp.array(nimp_bound_right),
),
'psi': dict(
right_face_grad_constraint=_calculate_psi_grad_constraint_from_Ip_tot(
right_face_grad_constraint=(
_calculate_psi_grad_constraint_from_Ip_tot(
dynamic_runtime_params_slice_t,
geo,
),
)
if dynamic_runtime_params_slice_t.profile_conditions.Vloop_bound_right
is None
else None
),
right_face_constraint=(
_psi_value_constraint_from_Vloop(
dynamic_runtime_params_slice_t,
Expand All @@ -980,7 +987,7 @@ def compute_boundary_conditions(
else None
),
),
}
}


# pylint: disable=invalid-name
Expand Down

0 comments on commit 93fd0aa

Please sign in to comment.