Skip to content

Commit

Permalink
Incremental progress on rewrites
Browse files Browse the repository at this point in the history
  • Loading branch information
cadojo committed Apr 7, 2021
1 parent c15c04a commit 3fcc7a4
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 45 deletions.
69 changes: 53 additions & 16 deletions src/CommonTypes/CommonTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,60 @@ Implementations are provided in TwoBody, and NBody.
"""
module CommonTypes


macro boilerplate(struct_definition)
firstline_post_struct = split(split(string(struct_definition), "\n")[1], "struct ")[2]
if ' ' firstline_post_struct
structname = split(firstline_post_struct, " ") |> first
end
if '{' structname
structname = split(structname, "{") |> first
end

if '\n' structname
structname = split(firstline_post_struct, "\n") |> first
end

structname = structname |> Symbol

convert = :(Base.convert(::Type{T}, o::$(esc(structname))) where {T<:Number} = $(esc(structname))(map(x-> typeof(x) <: Union{AbstractArray{<:Number}, Number} ? T.(x) : x, fieldnames($(esc(structname))))...))
if '{' firstline_post_struct
promote = :(Base.promote(::Type{$(esc(structname)){A}}, ::Type{$(esc(structname)){B}}) where {A,B} = $(esc(structname)){promote_type(A,B)})
else
promote = :(Base.promote(::Type($(esc(structname)), ::Type{$(esc(structname))})) = $(esc(structname)))
end
Float16 = :(Core.Float16(o::$(esc(structname))) = convert(Float16, o))
Float32 = :(Core.Float32(o::$(esc(structname))) = convert(Float32, o))
Float64 = :(Core.Float64(o::$(esc(structname))) = convert(Float64, o))
BigFloat = :(Base.MPFR.BigFloat(o::$(esc(structname))) = convert(BigFloat, o))

quote
$struct_definition
$convert
$promote
$Float16
$Float32
$Float64
end
end

macro export_boilerplate()
return :(export convert, promote, Float16, Float32, Float64)
end

export AbstractBody, AbstractOrbitalSystem, AbstractTrajectory, AbstractCartesianState, CartesianState
export getindex, setindex!, lengthunit, timeunit, velocityunit
export position_vector, velocity_vector, scalar_position, scalar_velocity
# export @boilerplate, @export_boilerplate
# @export_boilerplate

using Reexport
@reexport using Unitful, UnitfulAngles, UnitfulAstro
using StaticArrays: StaticVector, MVector

include("../Misc/DocStringExtensions.jl")
include("../Misc/UnitfulAliases.jl")

export AbstractBody, AbstractOrbitalSystem, AbstractTrajectory, AbstractCartesianState, CartesianState
export getindex, setindex!, lengthunit, timeunit, velocityunit
export position_vector, velocity_vector, scalar_position, scalar_velocity
export convert, promote, Float16, Float32, Float64, BigFloat

"""
Abstract type for bodies in space: both `CelestialBody`s (in
`TwoBody.jl`), and `Body`s (in `NBody.jl`).
Expand All @@ -37,10 +79,7 @@ Abstract type describing an orbital state.
"""
abstract type AbstractCartesianState{F<:AbstractFloat} <: StaticVector{6,F} end

"""
Cartesian state which describes a spacecraft or body's position and velocity with respect to _something_.
"""
mutable struct CartesianState{F<:AbstractFloat, LU, TU} <: AbstractCartesianState{F} where {LU<:Unitful.LengthUnits, TU<:Unitful.TimeUnits}
mutable struct CartesianState{F<:AbstractFloat, LU, TU} <: AbstractCartesianState{F} where {LU<:Unitful.LengthFreeUnits, TU<:Unitful.TimeFreeUnits}
r::SubArray{F, 1, MVector{6, F}, Tuple{UnitRange{Int64}}, true}
v::SubArray{F, 1, MVector{6, F}, Tuple{UnitRange{Int64}}, true}
rv::MVector{6,F}
Expand Down Expand Up @@ -76,12 +115,10 @@ mutable struct CartesianState{F<:AbstractFloat, LU, TU} <: AbstractCartesianStat
end
end

Base.convert(::Type{T}, o::CartesianState) where {T<:AbstractFloat} = CartesianState(T.(o.r), T.(o.v))
Base.promote(::Type{CartesianState{A}}, ::Type{CartesianState{B}}) where {A<:AbstractFloat, B<:AbstractFloat} = CartesianState{promote_type(A,B)}
Core.Float16(o::CartesianState) = convert(Float16, o)
Core.Float32(o::CartesianState) = convert(Float32, o)
Core.Float64(o::CartesianState) = convert(Float64, o)
Base.MPFR.BigFloat(o::CartesianState) = convert(BigFloat, o)
@doc "Cartesian state which describes a spacecraft or body's position and velocity with respect to _something_." CartesianState

# We need beyond the boilerplate for `CartesianState`...
Base.convert(::Type{T}, o::CartesianState) where {T<:CartesianState} = CartesianState(o.r / upreferred((1 * T.parameters[2]) / (1 * lengthunit(o))), o.v / upreferred((1 * T.parameters[2]) / (1 * lengthunit(o))) * upreferred((1 * T.parameters[3]) * (1 * timeunit(o))); lengthunit = T.parameters[2], timeunit = T.parameters[3])

Base.getindex(state::CartesianState, i::Int) = state.rv[i]
Base.setindex!(state::CartesianState, value, i::Int) = (state.rv[i] = value)
Expand All @@ -101,7 +138,7 @@ Returns the `Unitful.Velocity` unit associated with the Cartesian state.
"""
velocityunit(state::CartesianState) = lengthunit(state) / timeunit(state)

function Base.show(io::IO, ::MIME"text/plain", X::CartesianState{F, LU, TU}) where {F, LU, TU}
function Base.show(io::IO, ::MIME"text/plain", state::CartesianState{F, LU, TU}) where {F, LU, TU}
println(io, "Cartesian State of type $(string(F)):")
println(io, " r = ", [state.r[1] state.r[2] state.r[3]], " ", string(LU))
println(io, " v = ", [state.v[1] state.v[2] state.v[3]], " ", string(LU/TU))
Expand Down
4 changes: 2 additions & 2 deletions src/TwoBody/TwoBodyCalculations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ end
Returns semimajor axis parameter, a.
"""
semimajor_axis(r, v, μ) = inv( (2 / r) - (v^2 / μ) )
semimajor_axis(orbit::CartesianOrbit) = semimajor_axis(scalar_position(orbit), scalar_velocity(orbit), orbit.body.μ)
semimajor_axis(orbit::KeplerianOrbit) = orbit.state.a
semimajor_axis(orbit::CartesianOrbit) = semimajor_axis(scalar_position(orbit), scalar_velocity(orbit), orbit.body.μ)
semimajor_axis(orbit::KeplerianOrbit) = orbit.state.a # TODO define these functions for `KeplerianState` and dispatch here!

"""
Returns specific angular momentum vector, h̅.
Expand Down
32 changes: 5 additions & 27 deletions src/TwoBody/TwoBodyStates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ solar system bodies are supported:
Sun, Mercury, Venus, Earth, Moon (Luna), Mars, Jupiter,
Saturn, Uranus, Neptune, Pluto.
"""
struct CelestialBody{F<:AbstractFloat}
R::Length{F}
μ::MassParameter{F}
struct CelestialBody{F<:AbstractFloat, LU, MU} <: AbstractBody where {LU <: Unitful.LengthUnits, MU <: MassParameterUnits}
R::F
μ::F
name::String

function CelestialBody(m::Mass{<:AbstractFloat}, R::Length{<:AbstractFloat}, name::String="")
Expand Down Expand Up @@ -79,21 +79,13 @@ struct CelestialBody{F<:AbstractFloat}

end

Base.convert(::Type{T}, b::CelestialBody) where {T<:AbstractFloat} = CelestialBody(T(b.μ), T(b.R), b.name)
Base.promote(::Type{CelestialBody{A}}, ::Type{CelestialBody{B}}) where {A<:AbstractFloat, B<:AbstractFloat} = CelestialBody{promote_type(A,B)}
Core.Float16(o::CelestialBody) = convert(Float16, o)
Core.Float32(o::CelestialBody) = convert(Float32, o)
Core.Float64(o::CelestialBody) = convert(Float64, o)
Base.MPFR.BigFloat(o::CelestialBody) = convert(BigFloat, o)

"""
Custom display for `CelestialBody` instances.
"""
function Base.show(io::IO, body::CelestialBody)

println(io, crayon"blue", "CelestialBody:")
println(io, crayon"default",
" Mass: ", ustrip(u"kg", body.μ / G), " ", u"kg")
println(io, "CelestialBody:")
println(io, " Mass: ", ustrip(u"kg", body.μ / G), " ", u"kg")
println(io, " Radius: ", ustrip(u"km", body.R), " ", u"km")
println(io, " Mass Parameter: ", ustrip(u"km^3/s^2", body.μ), " ", u"km^3/s^2")

Expand Down Expand Up @@ -145,13 +137,6 @@ struct KeplerianState{F<:AbstractFloat, LU, AU} <: AbstractKeplerianState{F} whe

end

Base.convert(::Type{T}, o::KeplerianState) where {T<:AbstractFloat} = KeplerianState(T(o.e), T(o.a), T(o.i), T(o.Ω), T(o.ω), T(o.ν))
Base.promote(::Type{KeplerianState{A}}, ::Type{KeplerianState{B}}) where {A<:AbstractFloat, B<:AbstractFloat} = KeplerianState{promote_type(A,B)}
Core.Float16(o::KeplerianState) = convert(Float16, o)
Core.Float32(o::KeplerianState) = convert(Float32, o)
Core.Float64(o::KeplerianState) = convert(Float64, o)
Base.MPFR.BigFloat(o::KeplerianState) = convert(BigFloat, o)

"""
Returns the `Unitful.Length` unit associated with the Keplerian state.
"""
Expand Down Expand Up @@ -189,13 +174,6 @@ struct RestrictedTwoBodySystem{C<:AbstractConic, F<:AbstractFloat, T<:Union{Cart

end

Base.convert(::Type{T}, o::RestrictedTwoBodySystem) where {T<:AbstractFloat} = RestrictedTwoBodySystem(convert(T, o.state), convert(T, o.body))
Base.promote(::Type{RestrictedTwoBodySystem{A}}, ::Type{RestrictedTwoBodySystem{B}}) where {A<:AbstractFloat, B<:AbstractFloat} = Orbit{promote_type(A,B)}
Core.Float16(o::RestrictedTwoBodySystem) = convert(Float16, o)
Core.Float32(o::RestrictedTwoBodySystem) = convert(Float32, o)
Core.Float64(o::RestrictedTwoBodySystem) = convert(Float64, o)
Base.MPFR.BigFloat(o::RestrictedTwoBodySystem) = convert(BigFloat, o)

"""
Alias for `RestrictedTwoBodySystem`.
"""
Expand Down

0 comments on commit 3fcc7a4

Please sign in to comment.