Skip to content

Commit

Permalink
Up coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
mkelley committed Dec 7, 2023
1 parent c32cc79 commit b8bf96d
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 24 deletions.
45 changes: 30 additions & 15 deletions sbpy/activity/dust/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ class State:
t : `~astropy.time.Time`
Time, a scalar or shape = (N,).
frame : `~astropy.coordinates.BaseCoordinateFrame` class or string, optional
frame : `~astropy.coordinates.BaseCoordinateFrame` class or string,
optional
Coordinate frame for ``r`` and ``v``. Defaults to
`~astropy.coordinates.HeliocentricEclipticIAU76` if given as ``None``.
Expand Down Expand Up @@ -104,7 +105,12 @@ def __init__(
raise ValueError("Mismatch between lengths of vectors.")

def __repr__(self) -> str:
return f"<{type(self).__name__} ({self.frame}):\n r\n {self.r}\n v\n {self.v}\n t\n {self.t}>"
return (
f"<{type(self).__name__} ({self.frame}):\n"
+ f" r\n {self.r}\n"
+ f" v\n {self.v}\n"
+ f" t\n {self.t}>"
)

def __len__(self):
"""Number of state vectors in this object."""
Expand Down Expand Up @@ -271,7 +277,11 @@ def transform_to(self, frame: FrameType) -> StateType:

return State.from_skycoord(self.to_skycoord().transform_to(frame))

def observe(self, target: StateType, frame: Optional[FrameType] = None) -> SkyCoord:
def observe(
self,
target: StateType,
frame: Optional[FrameType] = None,
) -> SkyCoord:
"""Project a target's position on to the sky.
Expand Down Expand Up @@ -319,7 +329,11 @@ def from_states(cls, states: Iterable[StateType]) -> StateType:

r: np.ndarray = np.array([state.r for state in states])
v: np.ndarray = np.array([state.v for state in states])
t: Time = Time([state.t.tdb.et for state in states], scale="tdb", format="et")
t: Time = Time(
[state.t.tdb.et for state in states],
scale="tdb",
format="et",
)

return State(r, v, t, frame=list(frames)[0])

Expand All @@ -331,8 +345,8 @@ def from_skycoord(cls, coords: SkyCoord) -> StateType:
Parameters
----------
coords: ~astropy.coordinates.SkyCoord
The object state. Must have position and velocity, ``obstime``, and
be convertible to cartesian (3D) coordinates.
The object state. Must have position and velocity, ``obstime``,
and be convertible to cartesian (3D) coordinates.
"""

Expand All @@ -346,7 +360,11 @@ def from_skycoord(cls, coords: SkyCoord) -> StateType:

@classmethod
@sbd.dataclass_input
def from_ephem(cls, eph: Ephem, frame: Optional[FrameType] = None) -> StateType:
def from_ephem(
cls,
eph: Ephem,
frame: Optional[FrameType] = None,
) -> StateType:
"""Initialize from an `~sbpy.data.Ephem` object.
Expand All @@ -355,8 +373,8 @@ def from_ephem(cls, eph: Ephem, frame: Optional[FrameType] = None) -> StateType:
eph : ~sbpy.data.ephem.Ephem
Ephemeris object, must have time, position, and velocity. Position
and velocity may be specified using ("x", "y", "z", "vx", "vy", and
"vz"), or ("ra", "dec", "Delta", "RA*cos(Dec)_rate", "Dec_rate", and
"deltadot").
"vz"), or ("ra", "dec", "Delta", "RA*cos(Dec)_rate", "Dec_rate",
and "deltadot").
frame : string or `~astropy.coordinates.BaseCoordinateFrame`, optional
Transform the coordinates into this reference frame.
Expand Down Expand Up @@ -486,8 +504,7 @@ def df_drv(cls, t: float, rv: np.ndarray, *args) -> np.ndarray:
:math:`df/dv`.
"""



def solve(
self,
initial: State,
Expand Down Expand Up @@ -599,8 +616,7 @@ def df_drv(cls, t: float, rv: np.ndarray, *args) -> np.ndarray:
r = rv[:3]
r2 = (r**2).sum()
r1 = np.sqrt(r2)
GM_r3 = cls._GM
GM_r5 = GM_r3 / r2
GM_r5 = GM_r3 = cls._GM / (r2 * r2 * r1)

# df_drv[i, j] = df_i/drv_j
df_drv = np.zeros((6, 6))
Expand Down Expand Up @@ -682,8 +698,7 @@ def df_drv(cls, t: float, rv: np.ndarray, beta: float, *args) -> np.ndarray:
r2 = (r**2).sum()
r1 = np.sqrt(r2)
r3 = r1 * r2
GM_r3 = cls._GM / r3 * (1 - beta)
GM_r5 = GM_r3 / r2
GM_r5 = cls._GM * (1 - beta) / (r2 * r3)

# df_drv[i, j] = df_i/drv_j
df_drv = np.zeros((6, 6))
Expand Down
57 changes: 48 additions & 9 deletions sbpy/activity/dust/tests/test_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
SkyCoord,
HeliocentricEclipticIAU76,
)
import astropy.constants as const

from .... import time # for ephemeris time
from ....data import Ephem
Expand Down Expand Up @@ -539,6 +540,7 @@ def test_from_ephem(self):
with pytest.raises(ValueError):
State.from_ephem(incomplete)


def test_spice_prop2b():
"""Test case from SPICE NAIF toolkit prop2b, v2.2.0
Expand Down Expand Up @@ -582,6 +584,29 @@ class EarthGravity(SolarGravity):
assert np.allclose(final.v.value, [0, 0.04464, -0.04464], atol=0.00001)


class TestFreeExpansion:
def test(self):
r = [0, 1e6, 0]
v = [0, -1, 1]

solver = FreeExpansion()

initial = State(r, v, Time("2023-01-01"))
t_f = initial.t + 1e6 * u.s
final = solver.solve(initial, t_f)

assert np.allclose(final.r.value, [0, 0, 1e6], atol=2e-7)
assert np.allclose(final.v.value, [0, -1, 1])

solver = FreeExpansion(method="Radau")

initial = State(r, v, Time("2023-01-01"))
final = solver.solve(initial, t_f)

assert np.allclose(final.r.value, [0, 0, 1e6], atol=2e-7)
assert np.allclose(final.v.value, [0, -1, 1])


class TestSolarGravity:
@pytest.mark.parametrize("r1_au", ([0.3, 1, 3, 10, 30]))
def test_circular_orbit(self, r1_au):
Expand All @@ -607,20 +632,28 @@ def test_circular_orbit(self, r1_au):
assert np.allclose(final.r.value, initial.r.value)
assert np.allclose(final.v.value, initial.v.value)

t_f = initial.t + half_period * u.s
solver = SolarGravity(method="Radau")
final = solver.solve(initial, t_f)

class TestFreeExpansion:
def test(self):
r = [0, 1e6, 0]
v = [0, -1, 1]
assert np.allclose(final.r.value, -initial.r.value)
assert np.allclose(final.v.value, -initial.v.value)

solver = FreeExpansion()
def test_GM(self):
solver = SolarGravity()
assert u.isclose(solver.GM, const.G * const.M_sun)

def test_solverfailed(self):
r = [0, 1, 0] * u.au
v = [0, -1, 1] * u.km / u.s

# force a solution failure
solver = SolarGravity(rtol=np.nan)

initial = State(r, v, Time("2023-01-01"))
t_f = initial.t + 1e6 * u.s
final = solver.solve(initial, t_f)

assert np.allclose(final.r.value, [0, 0, 1e6], atol=2e-7)
assert np.allclose(final.v.value, [0, -1, 1])
with pytest.raises(SolverFailed):
solver.solve(initial, t_f)


class TestSolarGravityAndRadiationPressure:
Expand All @@ -644,3 +677,9 @@ class ReducedGravity(SolarGravity):

assert u.allclose(final1.r, final2.r)
assert u.allclose(final1.v, final2.v)

solver = SolarGravityAndRadiationPressure(method="Radau")
final2 = solver.solve(initial, t_f, beta)

assert u.allclose(final1.r, final2.r)
assert u.allclose(final1.v, final2.v)

0 comments on commit b8bf96d

Please sign in to comment.