Skip to content

Commit

Permalink
try using astropy to do frame conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
mkelley committed Nov 28, 2023
1 parent 4184ce3 commit e785a13
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 55 deletions.
150 changes: 110 additions & 40 deletions sbpy/activity/dust/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

__all__ = [
"State",
"FreeExpansion",
"SolarGravity",
"SolarGravityAndRadiationPressure",
Expand Down Expand Up @@ -76,9 +77,7 @@ class State:
-----
State is internally stored in units of km, km / s, and TDB seconds past
J2000.0 epoch. The internal reference frame is
`~astropy.coordinates.HeliocentricEclipticIAU76` to match what JPL/Horizons
and the NAIF SPICE toolkit use.
J2000.0 epoch.
"""

Expand All @@ -93,45 +92,94 @@ def __init__(
self.v = u.Quantity(v, "km/s")
self.t = Time(t, format="et", scale="tdb")

if not (self.r.shape[0] == self.v.shape[0] == self.t.shape[0]):
if (self.r.shape != self.v.shape) or (len(self) != len(self.t)):
raise ValueError("Mismatch between lengths of vectors.")

self.frame = "heliocentriceclipticiau76" if frame is None else frame
if self.frame != "heliocentriceclipticiau76":
raise NotImplementedError
# use astropy to convert between reference frames
self.r = SkyCoord(
x=self.r[..., 0],
y=self.r[..., 1],
z=self.r[..., 2],
v_x=self.v[..., 0],
v_y=self.v[..., 1],
v_z=self.v[..., 2],
obstime=self.t,
frame="heliocentriceclipticiau76" if frame is None else frame,
representation_type="cartesian",
)

def __len__(self):
"""Number of state vectors in this object."""
return self._r.shape[0]
if self._r.ndim == 1:
return 1
else:
return self._r.shape[0]

def __getitem__(self, k: int) -> StateType:
"""Get the state at index ``k``."""
def __getitem__(self, k: Union[int, tuple, slice]) -> StateType:
"""Get the state(s) at ``k``."""
return State(self.r[k], self.v[k], self.t[k], frame=self.frame)

@property
def r(self) -> u.Quantity[u.km]:
"""Position vector in the internal reference frame."""
"""Position vector."""
return u.Quantity(self._r, u.km)

@r.setter
@u.quantity_input
def r(self, r: u.Quantity[u.m]):
self._r = r.to_value(u.km).reshape((-1, 3))
def r(self, r: u.Quantity[u.km]):
if r.ndim > 3 or r.shape[r.ndim - 1] != 3:
raise ValueError("Must have shape (3,) or (N, 3).")
self._r = r.to_value(u.km)

@property
def x(self) -> u.Quantity[u.km]:
"""x component of the position vector."""
return self.r[..., 0]

@property
def y(self) -> u.Quantity[u.km]:
"""y component of the position vector."""
return self.r[..., 1]

@property
def z(self) -> u.Quantity[u.km]:
"""z component of the position vector."""
return self.r[..., 2]

@property
def v(self) -> u.Quantity[u.km / u.s]:
"""Velocity vector in the internal reference frame."""
"""Velocity vector."""
return u.Quantity(self._v, u.km / u.s)

@v.setter
@u.quantity_input
def v(self, v: u.Quantity[u.m / u.s]):
self._v = v.to_value(u.km / u.s).reshape((-1, 3))
def v(self, v: u.Quantity[u.km / u.s]):
if v.ndim > 3 or v.shape[v.ndim - 1] != 3:
raise ValueError("Must have shape (3,) or (N, 3).")
self._v = v.to_value(u.km / u.s)

@property
def v_x(self) -> u.Quantity[u.km / u.s]:
"""x component of the velocity vector."""
return self.v[..., 0]

@property
def v_y(self) -> u.Quantity[u.km / u.s]:
"""y component of the velocity vector."""
return self.v[..., 1]

@property
def v_z(self) -> u.Quantity[u.km / u.s]:
"""z component of the velocity vector."""
return self.v[..., 2]

@property
def rv(self) -> np.ndarray:
"""Position in km, and velocity in km/s."""
return np.hstack((self._r, self._v))
if self._r.ndim == 1:
return np.r_[self._r, self._v]
else:
return np.hstack((self._r, self._v))

@property
def t(self) -> Time:
Expand All @@ -143,17 +191,33 @@ def t(self, t):
self._t = t.tdb.to_value("et").reshape((-1,))

@property
def coords(self) -> SkyCoord:
"""State as a `~astropy.coordinates.SkyCoords` object."""
def skycoord(self) -> SkyCoord:
"""State as a `~astropy.coordinates.SkyCoord` object."""
return SkyCoord(
x=self.x,
y=self.y,
z=self.z,
v_x=self.v_x,
v_y=self.v_y,
v_z=self.v_z,
obstime=self.t,
frame=self.frame,
representation_type="cartesian",
)

def observe(self, observer: StateType):
"""Observe and return a `~astropy.coordinates.SkyCoord` object."""

target: State = State.from_skycoord(self.skycoord.transform_to(observer.frame))
return SkyCoord(
x=self.r[:, 0],
y=self.r[:, 1],
z=self.r[:, 2],
v_x=self.v[:, 0],
v_y=self.v[:, 1],
v_z=self.v[:, 2],
x=target.x - observer.x,
y=target.y - observer.y,
z=target.z - observer.z,
v_x=target.v_x - observer.v_x,
v_y=target.v_y - observer.v_y,
v_z=target.v_z - observer.v_z,
obstime=self.t,
frame="heliocentriceclipticiau76",
frame=observer.frame,
representation_type="cartesian",
)

Expand All @@ -174,9 +238,9 @@ def from_states(cls, states: Iterable[StateType]) -> StateType:
if len(frames) != 1:
raise ValueError("The coordinate frames must be identical.")

r: np.ndarray = np.array([state.r[0] for state in states])
v: np.ndarray = np.array([state.v[0] for state in states])
t: Time = Time([state.t[0] for state in states])
r: np.ndarray = np.array([state.r for state in states])
v: np.ndarray = np.array([state.v for state in states])
t: Time = Time([state.t.tdb.et for state in states], scale="tdb", format="et")

return State(r, v, t, frame=list(frames)[0])

Expand All @@ -193,12 +257,18 @@ def from_skycoord(cls, coords: SkyCoord) -> StateType:
"""

out_coords = coords.transform_to("heliocentriceclipticiau76")
out_coords.representation_type = "cartesian"
r: u.Quantity = u.Quantity([out_coords.x, out_coords.y, out_coords.z])
v: u.Quantity = u.Quantity([out_coords.v_x, out_coords.v_y, out_coords.v_z])
t: Time = out_coords.obstime
return cls(r, v, t)
r: u.Quantity = u.Quantity(
[coords.cartesian.x, coords.cartesian.y, coords.cartesian.z]
).T
v: u.Quantity = u.Quantity(
[
coords.cartesian.differentials["s"].d_x,
coords.cartesian.differentials["s"].d_y,
coords.cartesian.differentials["s"].d_z,
]
).T
t: Time = coords.obstime
return cls(r, v, t, frame=coords.frame)

@classmethod
@sbd.dataclass_input
Expand Down Expand Up @@ -393,23 +463,23 @@ def solve(
"""

final = State([0, 0, 0], [0, 0, 0], t_f, frame=initial.frame)
jac_sparsity: np.ndarray = np.zeros((6, 6))
jac_sparsity[0, 3:] = 1
jac_sparsity[3:, :3] = 1
# jac_sparsity: np.ndarray = np.zeros((6, 6))
# jac_sparsity[0, 3:] = 1
# jac_sparsity[3:, :3] = 1

ivp_kwargs = dict(
rtol=1e-8,
atol=[1e-4, 1e-4, 1e-4, 1e-10, 1e-10, 1e-10],
jac=cls.df_drv,
jac_sparsity=jac_sparsity, # not used for all methods
# jac_sparsity=jac_sparsity, # not used for all methods
method="LSODA",
)
ivp_kwargs.update(kwargs)

result = solve_ivp(
cls.dx_dt,
(initial.t.et[0], final.t.et[0]),
initial.rv[0],
initial.rv,
args=(beta,),
**ivp_kwargs,
)
Expand Down
85 changes: 70 additions & 15 deletions sbpy/activity/dust/syndynes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,17 @@

import time
import logging
from typing import List, Union, Optional
from typing import Iterable, List, Tuple, Union, Optional

import numpy as np

import astropy.units as u
from astropy.time import Time
from astropy.coordinates import (
BaseCoordinateFrame,
SkyCoord,
get_body_barycentric_posvel,
)

from .dynamics import State, SolarGravity, SolarGravityAndRadiationPressure

Expand Down Expand Up @@ -55,27 +60,62 @@ class Syndynes:
ages : ~astropy.units.Quantity, optional
Array of particle ages (time).
observer : State, optional
State vector of the observer in the same reference frame as ``source``.
Default is the Earth obtained via ``astropy.coordinates.get_body``.
"""

def __init__(
self,
source: State,
betas: Optional[Union[np.ndarray, u.Quantity]],
ages: Optional[u.Quantity],
betas: Union[Iterable, u.Quantity[u.dimensionless_unscaled]],
ages: u.Quantity[u.s],
observer: Optional[State] = None,
) -> None:
if len(source) != 1:
raise ValueError("Only one source state vector allowed.")

self.source: State = source
self.betas: u.Quantity[""] = u.Quantity(betas, "").reshape((-1,))
self.ages: u.Quantity["s"] = u.Quantity(ages, "s").reshape((-1,))
self.betas: u.Quantity[u.dimensionless_unscaled] = u.Quantity(
betas, ""
).reshape((-1,))
self.ages: u.Quantity[u.s] = u.Quantity(ages, "s").reshape((-1,))

self.observer: State
if observer is None:
# use the Earth
r_e: SkyCoord
v_e: SkyCoord
t: Time = source.t.reshape(())
r_e, v_e = get_body_barycentric_posvel("earth", t)
self.observer = State.from_skycoord(
SkyCoord(
x=r_e.x,
y=r_e.y,
z=r_e.z,
v_x=v_e.x,
v_y=v_e.y,
v_z=v_e.z,
obstime=t,
frame="icrs",
representation_type="cartesian",
)
)
# self.observer = State.from_skycoord(
# get_body("earth", source.t).transform_to(source.frame)
# )
elif observer.frame != source.frame:
raise ValueError("source and observer frames are not equal.")
else:
self.observer = observer

self.solve()

def __repr__(self) -> str:
return f"<Syndynes: {len(self.betas)} beta values, {len(self.ages)} time steps>"

def initialize_states(self) -> None:
def _initialize_states(self) -> None:
"""Generate the initial particle states."""

# integrate from observation time, t_f, back to t_i
Expand All @@ -97,17 +137,19 @@ def solve(self) -> None:

logger: logging.Logger = logging.getLogger()

self.initialize_states()
self._initialize_states()

self.r: np.ndarray = np.zeros((self.betas.size, self.ages.size, 3))
particles: List[State] = []
t0: float = time.monotonic()
for i in range(self.betas.size):
for j in range(self.ages.size):
state: State = SolarGravityAndRadiationPressure.solve(
self.initial_states[j], self.source.t, self.betas[i]
particles.append(
SolarGravityAndRadiationPressure.solve(
self.initial_states[j], self.source.t, self.betas[i]
)
)
self.r[i, j] = state.r
t1: float = time.monotonic()
self.particles = State.from_states(particles)

logger.info(
"Solved for %d syndynes, %d time steps each.",
Expand All @@ -120,15 +162,25 @@ def solve(self) -> None:
def syndynes(self):
"""Iterator for each syndyne."""

for i in range(self.betas.size):
yield self.r[i]
n: int = self.ages.size
i: int
beta: float
for i, beta in enumerate(self.betas):
syn = self.particles[i * n : i + n]
coords = syn.observe(self.observer)
yield beta, syn, coords

@property
def synchrones(self):
"""Iterator for each synchrone."""

for i in range(self.ages.size):
yield self.r[:, i]
n: int = self.betas.size
i: int
age: u.Quantity[u.s]
for i, age in enumerate(self.ages):
syn = self.particles[i::n]
coords = syn.observe(self.observer)
yield age, syn, coords

def get_syndyne(self, beta: float) -> np.ndarray:
"""Get the positions of a single syndyne.
Expand Down Expand Up @@ -179,3 +231,6 @@ def get_synchrone(self, age: u.Quantity[u.s]) -> np.ndarray:
raise ValueError(f"Age not found: {age}")

return self.r[:, i]

def get_orbit(self, ages: u.Quantity[u.s]) -> Tuple[State, SkyCoord]:
pass
Loading

0 comments on commit e785a13

Please sign in to comment.