Skip to content

Commit

Permalink
add a few missing methods and fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
exaexa committed Oct 7, 2023
1 parent 26b9f72 commit e2e855f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 13 deletions.
26 changes: 17 additions & 9 deletions src/tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ Base.@kwdef struct Tree{X}
construct the inner dictionary.
"""
Tree{X}(x...) where {X} = new{X}(SortedDict{Symbol,Union{X,Tree{X}}}(x...))

# TODO Tree could be a proper subtype of AbstractDict, but currently that
# fails due to circular use of the type in its own parameter. Might be hard
# to do that properly.
end

"""
Expand All @@ -31,28 +35,32 @@ simpler way to get the elements without an explicit use of `getfield`.
"""
elems(x::Tree) = getfield(x, :elems)

function Base.getproperty(x::Tree, sym::Symbol)
elems(x)[sym]
end

Base.keys(x::Tree) = keys(elems(x))

Base.values(x::Tree) = values(elems(x))
Base.isempty(x::Tree) = isempty(elems(x))

Base.length(x::Tree) = length(elems(x))

Base.iterate(x::Tree) = iterate(elems(x))
Base.iterate(x::Tree, st) = iterate(elems(x), st)

Base.eltype(x::Tree) = eltype(elems(x))

Base.keytype(x::Tree) = keytype(elems(x))

Base.keys(x::Tree) = keys(elems(x))

Base.haskey(x::Tree, sym::Symbol) = haskey(elems(x), sym)

Base.valtype(x::Tree) = valtype(elems(x))

Base.eltype(x::Tree) = eltype(elems(x))
Base.values(x::Tree) = values(elems(x))

Base.getindex(x::Tree, sym::Symbol) = getindex(elems(x), sym)

Base.propertynames(x::Tree) = keys(x)

Base.getindex(x::Tree, sym::Symbol) = getindex(elems(x), sym)
Base.hasproperty(x::Tree, sym::Symbol) = haskey(x, sym)

Base.getproperty(x::Tree, sym::Symbol) = elems(x)[sym]

#
# Algebraic construction
Expand Down
8 changes: 4 additions & 4 deletions test/misc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ end
@testset "Solution tree operations" begin
ct = C.variables(keys = [:a, :b])

@test_throws BoundsError C.SolutionTree(ct, [1.0])
st = C.SolutionTree(ct, [123.0, 321.0])
@test_throws BoundsError C.ValueTree(ct, [1.0])
st = C.ValueTree(ct, [123.0, 321.0])

@test isempty(C.SolutionTree())
@test isempty(C.SolutionTree(C.ConstraintTree(), Float64[]))
@test isempty(C.ValueTree())
@test isempty(C.ValueTree(C.ConstraintTree(), Float64[]))
@test !isempty(st)
@test haskey(st, :a)
@test hasproperty(st, :a)
Expand Down

0 comments on commit e2e855f

Please sign in to comment.