Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Constrain type in to_vec(::AbstractArray/Vector) #156

Merged
merged 13 commits into from
Apr 28, 2021
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FiniteDifferences"
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
version = "0.12.2"
version = "0.13.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
9 changes: 4 additions & 5 deletions src/to_vec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ function to_vec(x::T) where {T}
v, vals_from_vec = to_vec(vals)
function structtype_from_vec(v::Vector{<:Real})
val_vecs = vals_from_vec(v)
vals = map((b, v) -> b(v), backs, val_vecs)
return T(vals...)
values = map((b, v) -> b(v), backs, val_vecs)
return T(values...)
end
return v, structtype_from_vec
end

function to_vec(x::AbstractVector)
function to_vec(x::DenseVector)
x_vecs_and_backs = map(to_vec, x)
x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
function Vector_from_vec(x_vec)
Expand All @@ -53,7 +53,7 @@ function to_vec(x::AbstractVector)
return x_vec, Vector_from_vec
end

function to_vec(x::AbstractArray)
function to_vec(x::DenseArray)
x_vec, from_vec = to_vec(vec(x))

function Array_from_vec(x_vec)
Expand All @@ -63,7 +63,6 @@ function to_vec(x::AbstractArray)
return x_vec, Array_from_vec
end


# Some specific subtypes of AbstractArray.
function to_vec(x::Base.ReshapedArray{<:Any, 1})
x_vec, from_vec = to_vec(parent(x))
Expand Down
17 changes: 7 additions & 10 deletions test/to_vec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ end
Base.:(==)(x::DummyType, y::DummyType) = x.X == y.X
Base.length(x::DummyType) = size(x.X, 1)

# A dummy FillVector. This is a type for which the fallback implementation of
# `to_vec` should fail loudly.
# A dummy FillVector
struct FillVector <: AbstractVector{Float64}
x::Float64
len::Int
end

Base.size(x::FillVector) = (x.len,)
Base.getindex(x::FillVector, n::Int) = x.x

# For testing Composite{ThreeFields}
struct ThreeFields
a
Expand All @@ -32,9 +34,6 @@ struct Nested
y::Singleton
end

Base.size(x::FillVector) = (x.len,)
Base.getindex(x::FillVector, n::Int) = x.x

function test_to_vec(x::T; check_inferred = true) where {T}
check_inferred && @inferred to_vec(x)
x_vec, back = to_vec(x)
Expand Down Expand Up @@ -67,8 +66,8 @@ end
test_to_vec(UpperTriangular(randn(T, 13, 13)))
test_to_vec(Diagonal(randn(T, 7)))
test_to_vec(DummyType(randn(T, 2, 9)))
test_to_vec(SVector{2, T}(1.0, 2.0))
test_to_vec(SMatrix{2, 2, T}(1.0, 2.0, 3.0, 4.0))
test_to_vec(SVector{2, T}(1.0, 2.0); check_inferred = false)
test_to_vec(SMatrix{2, 2, T}(1.0, 2.0, 3.0, 4.0); check_inferred = false)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
test_to_vec(SVector{2, T}(1.0, 2.0); check_inferred = false)
test_to_vec(SMatrix{2, 2, T}(1.0, 2.0, 3.0, 4.0); check_inferred = false)
test_to_vec(SVector{2, T}(1.0, 2.0); check_inferred=false)
test_to_vec(SMatrix{2, 2, T}(1.0, 2.0, 3.0, 4.0); check_inferred=false)

Bluetyle spacing in kwargs

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed in other places as well now

test_to_vec(@view randn(T, 10)[1:4]) # SubArray -- Vector
test_to_vec(@view randn(T, 10, 2)[1:4, :]) # SubArray -- Matrix
test_to_vec(Base.ReshapedArray(rand(T, 3, 3), (9,), ()))
Expand Down Expand Up @@ -173,9 +172,7 @@ end
end

@testset "FillVector" begin
x = FillVector(5.0, 10)
x_vec, from_vec = to_vec(x)
@test_throws MethodError from_vec(randn(10))
test_to_vec(FillVector(5.0, 10); check_inferred=false)
end

@testset "fallback" begin
Expand Down