Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove AbstractTOMLDict #197

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/src/API.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ CurrentModule = ClimaParams
## Parameter dictionaries

```@docs
AbstractTOMLDict
ParamDict
```

Expand Down
1 change: 0 additions & 1 deletion src/ClimaParams.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ module ClimaParams
using TOML
using DocStringExtensions

export AbstractTOMLDict
export ParamDict

export float_type,
Expand Down
92 changes: 38 additions & 54 deletions src/file_parsing.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
"""
AbstractTOMLDict{FT <: AbstractFloat}

Abstract parameter dict. One subtype:
- [`ParamDict`](@ref)
"""
abstract type AbstractTOMLDict{FT <: AbstractFloat} end

const NAMESTYPE =
Union{AbstractVector{S}, NTuple{N, S} where {N}} where {S <: AbstractString}

Expand All @@ -21,27 +13,29 @@

$(DocStringExtensions.FIELDS)
"""
struct ParamDict{FT} <: AbstractTOMLDict{FT}
struct ParamDict{FT <: AbstractFloat}
"dictionary representing a default/merged parameter TOML file"
data::Dict
"either a nothing, or a dictionary representing an override parameter TOML file"
override_dict::Union{Nothing, Dict}
end

"""
float_type(::AbstractTOMLDict)
float_type(::ParamDict)

The float type from the parameter dict.
"""
float_type(::AbstractTOMLDict{FT}) where {FT} = FT
float_type(::ParamDict{FT}) where {FT} = FT

Base.iterate(pd::ParamDict, state) = Base.iterate(pd.data, state)
Base.iterate(pd::ParamDict) = Base.iterate(pd.data)

Base.getindex(pd::ParamDict, i) = getindex(pd.data, i)

Base.print(td::ParamDict, io = stdout) = TOML.print(io, td.data)

Check warning on line 35 in src/file_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/file_parsing.jl#L35

Added line #L35 was not covered by tests

"""
log_component!(pd::AbstractTOMLDict, names, component)
log_component!(pd::ParamDict, names, component)

Adds a new key,val pair: `("used_in",component)` to each
named parameter in `pd`.
Expand Down Expand Up @@ -75,12 +69,7 @@
- `String` if type=\"string\"
Default type of `String` is used if no type is provided.
"""
function _get_typed_value(
pd::AbstractTOMLDict,
val,
valname::AbstractString,
valtype,
)
function _get_typed_value(pd::ParamDict, val, valname::AbstractString, valtype)

if valtype == "float"
return float_type(pd)(val)
Expand All @@ -103,7 +92,7 @@
end

"""
get_values(pd::AbstractTOMLDict, names)
get_values(pd::ParamDict, names)

Gets the values of the parameters in `names` from the TOML dict `pd`.
"""
Expand All @@ -129,13 +118,13 @@

"""
get_parameter_values(
pd::AbstractTOMLDict,
pd::ParamDict,
names::Union{String,Vector{String}},
component::String
)

get_parameter_values(
pd::AbstractTOMLDict,
pd::ParamDict,
name_map::Union{Dict, Vector{Pair}, NTuple{N, Pair}, Vararg{Pair}},
component::String
)
Expand All @@ -149,15 +138,15 @@
from the long names and returns a NamedTuple where the keys are the variable names.
"""
function get_parameter_values(
pd::AbstractTOMLDict,
pd::ParamDict,
names::AbstractString,
component = nothing,
)
return get_parameter_values(pd, [names], component)
end

function get_parameter_values(
pd::AbstractTOMLDict,
pd::ParamDict,
names::NAMESTYPE,
component::Union{AbstractString, Nothing} = nothing,
)
Expand All @@ -168,15 +157,15 @@
end

function get_parameter_values(
pd::AbstractTOMLDict,
pd::ParamDict,
name_map::Union{AbstractVector{Pair{S, S}}, NTuple{N, Pair}},
component = nothing,
) where {S, N}
return get_parameter_values(pd, Dict(name_map), component)
end

function get_parameter_values(
pd::AbstractTOMLDict,
pd::ParamDict,
name_map::Vararg{Pair};
component = nothing,
)
Expand All @@ -188,7 +177,7 @@
end

function get_parameter_values(
pd::AbstractTOMLDict,
pd::ParamDict,
name_map::Dict{S, S},
component = nothing,
) where {S <: AbstractString}
Expand All @@ -201,15 +190,15 @@
end

function get_parameter_values(
pd::AbstractTOMLDict,
pd::ParamDict,
name_map::NamedTuple,
component = nothing,
)
return get_parameter_values(pd, Dict(pairs(name_map)), component)
end

function get_parameter_values(
pd::AbstractTOMLDict,
pd::ParamDict,
name_map::Dict{Symbol, Symbol},
component = nothing,
)
Expand Down Expand Up @@ -287,14 +276,13 @@
Throws warnings in each where parameters are not used. Also throws
an error if `strict == true` .
"""
check_override_parameter_usage(pd::AbstractTOMLDict, strict::Bool) =
check_override_parameter_usage(pd::ParamDict, strict::Bool) =
check_override_parameter_usage(pd, strict, pd.override_dict)

check_override_parameter_usage(pd::AbstractTOMLDict, strict::Bool, ::Nothing) =
nothing
check_override_parameter_usage(pd::ParamDict, strict::Bool, ::Nothing) = nothing

function check_override_parameter_usage(
pd::AbstractTOMLDict,
pd::ParamDict,
strict::Bool,
override_dict,
)
Expand Down Expand Up @@ -329,12 +317,12 @@
end

"""
write_log_file(pd::AbstractTOMLDict, filepath)
write_log_file(pd::ParamDict, filepath)

Writes a log file of all used parameters of `pd` at
the `filepath`. This file can be used to rerun the experiment.
"""
function write_log_file(pd::AbstractTOMLDict, filepath::AbstractString)
function write_log_file(pd::ParamDict, filepath::AbstractString)
used_parameters = Dict()
for (key, val) in pd.data
if "used_in" in keys(val)
Expand All @@ -349,7 +337,7 @@

"""
log_parameter_information(
pd::AbstractTOMLDict,
pd::ParamDict,
filepath;
strict::Bool = false
)
Expand All @@ -360,7 +348,7 @@
If `strict = true`, errors if override parameters are unused.
"""
function log_parameter_information(
pd::AbstractTOMLDict,
pd::ParamDict,
filepath::AbstractString;
strict::Bool = false,
)
Expand All @@ -372,17 +360,17 @@

"""
merge_override_default_values(
override_toml_dict::AbstractTOMLDict{FT},
default_toml_dict::AbstractTOMLDict{FT}
override_toml_dict::ParamDict,
default_toml_dict::ParamDict
) where {FT}

Combines the `default_toml_dict` with the `override_toml_dict`,
precedence is given to override information.
"""
function merge_override_default_values(
override_toml_dict::PDT,
default_toml_dict::PDT,
) where {FT, PDT <: AbstractTOMLDict{FT}}
override_toml_dict::ParamDict{FT},
default_toml_dict::ParamDict{FT},
) where {FT}
data = default_toml_dict.data
override_dict = override_toml_dict.override_dict
for (key, val) in override_toml_dict.data
Expand All @@ -394,7 +382,7 @@
end
end
end
return PDT(data, override_dict)
return ParamDict{FT}(data, override_dict)
end

"""
Expand Down Expand Up @@ -425,17 +413,13 @@
return merge_override_default_values(override_toml_dict, default_toml_dict)
end

# Extend Base.print to AbstractTOMLDict
Base.print(td::AbstractTOMLDict, io = stdout) = TOML.print(io, td.data)


"""
get_tagged_parameter_names(pd::AbstractTOMLDict, tag::AbstractString)
get_tagged_parameter_names(pd::AbstractTOMLDict, tags::Vector{AbstractString})
get_tagged_parameter_names(pd::ParamDict, tag::AbstractString)
get_tagged_parameter_names(pd::ParamDict, tags::Vector{AbstractString})

Returns a list of the parameters with a given tag.
"""
function get_tagged_parameter_names(pd::AbstractTOMLDict, tag::AbstractString)
function get_tagged_parameter_names(pd::ParamDict, tag::AbstractString)
data = pd.data
ret_values = String[]
for (key, val) in data
Expand All @@ -447,7 +431,7 @@
end

get_tagged_parameter_names(
pd::AbstractTOMLDict,
pd::ParamDict,
tags::Vector{S},
) where {S <: AbstractString} =
vcat(map(x -> get_tagged_parameter_names(pd, x), tags)...)
Expand All @@ -464,16 +448,16 @@
end

"""
get_tagged_parameter_values(pd::AbstractTOMLDict, tag::AbstractString)
get_tagged_parameter_values(pd::AbstractTOMLDict, tags::Vector{AbstractString})
get_tagged_parameter_values(pd::ParamDict, tag::AbstractString)
get_tagged_parameter_values(pd::ParamDict, tags::Vector{AbstractString})

Returns a list of name-value Pairs of the parameters with the given tag(s).
"""
get_tagged_parameter_values(pd::AbstractTOMLDict, tag::AbstractString) =
get_tagged_parameter_values(pd::ParamDict, tag::AbstractString) =
get_parameter_values(pd, get_tagged_parameter_names(pd, tag))

get_tagged_parameter_values(
pd::AbstractTOMLDict,
pd::ParamDict,
tags::Vector{S},
) where {S <: AbstractString} =
merge(map(x -> get_tagged_parameter_values(pd, x), tags)...)
Loading