Skip to content

Commit

Permalink
Added symbol input for coefficients, improved tests, assert BIa in BIb
Browse files Browse the repository at this point in the history
  • Loading branch information
Vizia128 committed Jul 25, 2023
1 parent 442133e commit ee19ae0
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 53 deletions.
6 changes: 3 additions & 3 deletions src/arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -640,9 +640,9 @@ function broadcasted(::typeof(/), a::MultiVector{CA,Ta,BI}, b::MultiVector{CA,Tb
end

function broadcasted(::typeof(/), a::MultiVector{CA,Ta,BIa}, b::MultiVector{CA,Tb,BIb})::MultiVector where {CA,Ta,Tb,BIa,BIb}
BI = tuple(union(BIa, BIb)...)
v1, v2 = coefficients(a, BI), coefficients(b, BI)
return MultiVector(CA, BI, v1 ./ v2)
@assert BIa BIb
v1, v2 = coefficients(a, BIb), coefficients(b, BIb)
return MultiVector(CA, BIb, v1 ./ v2)
end

"""
Expand Down
27 changes: 23 additions & 4 deletions src/multivector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,15 @@ coefficients(mv::MultiVector) = getfield(mv,:c)


"""
coefficient(::MultiVector, n::Union{NTuple, AbstractVector})
coefficients(::MultiVector, n::Union{NTuple{Integer}, AbstractVector{Integer}})
Returns the multivector coefficients for the given basis tuple or vector.
Returns the multivector coefficients for the given tuple/vector of basis indices.
Returns 0 if index is out of bounds.
"""
function coefficients(
mv::MultiVector{CA,T},
idxs::I
)::I where {CA, T, I<:Union{NTuple, AbstractVector}}
idxs::U
) where {CA, T, U<:Union{NTuple{N,I} where {N}, AbstractVector{I}} where {I<:Integer}}
bases = baseindices(mv)
coeffs = getfield(mv, :c)

Expand All @@ -124,6 +124,25 @@ function coefficients(
end
end

"""
coefficients(::MultiVector, n::Union{NTuple{Symbol}, AbstractVector{Symbol}})
Returns the multivector coefficients for the given tuple/vector of basis symbols.
Returns 0 if the symbol is not a valid basis symbol.
"""
function coefficients(
mv::MultiVector{CA,T},
syms::U
) where {CA, T, U<:Union{NTuple{N,Symbol} where {N}, AbstractVector{Symbol}}}
bases = baseindices(mv)
coeffs = getfield(mv, :c)

return map(syms) do sym
n = findfirst(i -> isequal(sym, basesymbol(CA,i)), bases)
isnothing(n) ? zero(T) : coeffs[n]
end
end


"""
coefficient(::MultiVector, n::Integer)
Expand Down
92 changes: 46 additions & 46 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -456,67 +456,67 @@ import LinearAlgebra.SingularException

@testset "coefficients" begin
pga = typeof(CliffordAlgebra(:PGA3D))
mv = MultiVector(pga, (1:16...,), (101:116...,))
mv_int = MultiVector(pga, (1:16...,), (101:116...,))
mv_float = MultiVector(pga, (1:16...,), (101.5:116.5...,))

fib_ntuple = (1, 2, 3, 5, 8, 13)
fib_vec = [1, 2, 3, 5, 8, 13]
sym_ntuple = (:𝟏, :e1, :e2, :e0, :e2e3, :e1e2e0)
sym_vec = [:𝟏, :e1, :e2, :e0, :e2e3, :e1e2e0]

@test coefficients(mv_int, fib_ntuple) == fib_ntuple .+ 100
@test coefficients(mv_int, fib_vec) == fib_vec .+ 100
@test coefficients(mv_float, fib_ntuple) == fib_ntuple .+ 100.5
@test coefficients(mv_float, fib_vec) == fib_vec .+ 100.5

@test coefficients(mv, fib_ntuple) == fib_ntuple .+ 100
@test coefficients(mv, fib_vec) == fib_vec .+ 100
@test coefficients(mv_int, sym_ntuple) == fib_ntuple .+ 100
@test coefficients(mv_int, sym_vec) == fib_vec .+ 100
@test coefficients(mv_float, sym_ntuple) == fib_ntuple .+ 100.5
@test coefficients(mv_float, sym_vec) == fib_vec .+ 100.5
end

@testset "broadcasted" begin
@testset "broadcasted .*" begin
pga = typeof(CliffordAlgebra(:PGA2D))
mvs = [
MultiVector(pga, (1:3...,), (2,2,2)),
MultiVector(pga, (1:6...,), (3,3,3,3,3,3)),
MultiVector(pga, (4:8...,), (5,5,5,5,5)),
MultiVector(pga, (1:8...,), (7,7,7,7,7,7,7,7)),
]

mv_muls = [
MultiVector(pga, (1, 2, 3), (4, 4, 4))
MultiVector(pga, (1, 2, 3), (6, 6, 6))
MultiVector(pga, (1, 2, 3, 4, 5, 6, 7, 8), (0, 0, 0, 0, 0, 0, 0, 0))
MultiVector(pga, (1, 2, 3), (14, 14, 14))
MultiVector(pga, (1, 2, 3), (6, 6, 6))
MultiVector(pga, (1, 2, 3, 4, 5, 6), (9, 9, 9, 9, 9, 9))
MultiVector(pga, (4, 5, 6), (15, 15, 15))
MultiVector(pga, (1, 2, 3, 4, 5, 6), (21, 21, 21, 21, 21, 21))
MultiVector(pga, (1, 2, 3, 4, 5, 6, 7, 8), (0, 0, 0, 0, 0, 0, 0, 0))
MultiVector(pga, (4, 5, 6), (15, 15, 15))
MultiVector(pga, (4, 5, 6, 7, 8), (25, 25, 25, 25, 25))
MultiVector(pga, (4, 5, 6, 7, 8), (35, 35, 35, 35, 35))
MultiVector(pga, (1, 2, 3), (14, 14, 14))
MultiVector(pga, (1, 2, 3, 4, 5, 6), (21, 21, 21, 21, 21, 21))
MultiVector(pga, (4, 5, 6, 7, 8), (35, 35, 35, 35, 35))
MultiVector(pga, (1, 2, 3, 4, 5, 6, 7, 8), (49, 49, 49, 49, 49, 49, 49, 49))
]

mv_divs = [
MultiVector(pga, (1, 2, 3), (1.0, 1.0, 1.0))
MultiVector(pga, (1, 2, 3), (0.6666666666666666, 0.6666666666666666, 0.6666666666666666))
MultiVector(pga, (1, 2, 3), (Inf, Inf, Inf))
MultiVector(pga, (1, 2, 3), (0.2857142857142857, 0.2857142857142857, 0.2857142857142857))
MultiVector(pga, (1, 2, 3, 4, 5, 6), (1.5, 1.5, 1.5, Inf, Inf, Inf))
MultiVector(pga, (1, 2, 3, 4, 5, 6), (1.0, 1.0, 1.0, 1.0, 1.0, 1.0))
MultiVector(pga, (1, 2, 3, 4, 5, 6), (Inf, Inf, Inf, 0.6, 0.6, 0.6))
MultiVector(pga, (1, 2, 3, 4, 5, 6), (0.42857142857142855, 0.42857142857142855, 0.42857142857142855, 0.42857142857142855, 0.42857142857142855, 0.42857142857142855))
MultiVector(pga, (4, 5, 6, 7, 8), (Inf, Inf, Inf, Inf, Inf))
MultiVector(pga, (4, 5, 6, 7, 8), (1.6666666666666667, 1.6666666666666667, 1.6666666666666667, Inf, Inf))
MultiVector(pga, (4, 5, 6, 7, 8), (1.0, 1.0, 1.0, 1.0, 1.0))
MultiVector(pga, (4, 5, 6, 7, 8), (0.7142857142857143, 0.7142857142857143, 0.7142857142857143, 0.7142857142857143, 0.7142857142857143))
MultiVector(pga, (1, 2, 3, 4, 5, 6, 7, 8), (3.5, 3.5, 3.5, Inf, Inf, Inf, Inf, Inf))
MultiVector(pga, (1, 2, 3, 4, 5, 6, 7, 8), (2.3333333333333335, 2.3333333333333335, 2.3333333333333335, 2.3333333333333335, 2.3333333333333335, 2.3333333333333335, Inf, Inf))
MultiVector(pga, (1, 2, 3, 4, 5, 6, 7, 8), (Inf, Inf, Inf, 1.4, 1.4, 1.4, 1.4, 1.4))
MultiVector(pga, (1, 2, 3, 4, 5, 6, 7, 8), (1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0))
]

for ((mv1, mv2), mv_mul, mv_div) in zip(Iterators.product(mvs, mvs), mv_muls, mv_divs)
@test isapprox((mv1 .* mv2), mv_mul)
@test isapprox(vector(mv2 ./ mv1), vector(mv_div))

for mv1 in mvs, mv2 in mvs
@test isapprox(vector(mv1 .* mv2), vector(mv1) .* vector(mv2))
end
end

@testset "broadcasted ./" begin
pga = typeof(CliffordAlgebra(:PGA2D))

function semi_safe_divide(x, y)
if x == 0 && y == 0
return 0
else
return x / y
end
end

mv_zero = MultiVector(pga, (1:3...,), (0, 0, 0))
mv_small = MultiVector(pga, (1:3...,), (2, 2, 2))
mv_full = MultiVector(pga, (1:8...,), (7, 7, 7, 7, 7, 7, 7, 7))

@test isapprox(vector(mv_small ./ mv_zero), [Inf, Inf, Inf, 0, 0, 0, 0, 0])
@test isapprox(vector(mv_small ./ mv_zero), semi_safe_divide.(vector(mv_small), vector(mv_zero)))
@test_throws AssertionError mv_full ./ mv_zero

@test isapprox(vector(mv_zero ./ mv_small), semi_safe_divide.(vector(mv_zero), vector(mv_small)))
@test isapprox(vector(mv_small ./ mv_small), semi_safe_divide.(vector(mv_small), vector(mv_small)))
@test_throws AssertionError mv_full ./ mv_small

@test isapprox(vector(mv_zero ./ mv_full), semi_safe_divide.(vector(mv_zero), vector(mv_full)))
@test isapprox(vector(mv_small ./ mv_full), semi_safe_divide.(vector(mv_small), vector(mv_full)))
@test isapprox(vector(mv_full ./ mv_full), semi_safe_divide.(vector(mv_full), vector(mv_full)))
end

@testset "norm" begin
r = -100:100
mv = rand(r) + rand(r) * e1 +
Expand Down

0 comments on commit ee19ae0

Please sign in to comment.