Skip to content

Commit

Permalink
Merge pull request #555 from willow-ahrens/wma/fix-fileio
Browse files Browse the repository at this point in the history
fixing binsparse
  • Loading branch information
willow-ahrens authored May 14, 2024
2 parents 9052e4b + 653c966 commit e6622c4
Show file tree
Hide file tree
Showing 8 changed files with 206 additions and 143 deletions.
2 changes: 1 addition & 1 deletion src/Finch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ export lazy, compute, tensordot, @einsum

export choose, minby, maxby, overwrite, initwrite, filterop, d

export default, AsArray
export default, AsArray, expanddims

export parallelAnalysis, ParallelAnalysisResults
export parallel, realextent, extent, dimless
Expand Down
197 changes: 101 additions & 96 deletions src/interface/fileio/binsparse.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Version tag for the binsparse header: interpolated into the "version" string
# when a tensor is written and compared against the header's "version" when one
# is read.
# NOTE(review): stored as a floating-point literal, so the string form depends
# on float printing ("0.1"); confirm this is intended before changing the value.
const BINSPARSE_VERSION = 0.1

"""
bspwrite(::AbstractString, tns)
bspwrite(::HDF5.File, tns)
Expand Down Expand Up @@ -88,161 +90,161 @@ function bspwrite_data_helper(f, desc, key, data::AbstractVector{Complex{T}}) wh
desc["data_types"][key] = "complex[$(desc["data_types"][key])]"
end

# Table from binsparse format-name strings ("DVEC", "CSR", "COO", ...) to a
# nested description of the levels that make up that format.
# In the newer spelling every level carries a "level_kind" ("dense", "sparse"
# or "element"); non-element levels also carry a "rank" and a nested "level".
# The DMATC, CSC, DCSC and COOC entries additionally carry
# "transpose" => [1, 0].
# NOTE(review): this view appears to interleave the superseded spellings
# (bspread_format_lookup, "subformat", string-valued "level", "swizzle") with
# their replacements (bspread_tensor_lookup, "level", "level_kind",
# "transpose"); only one line of each pair belongs to a given revision.
bspread_format_lookup = OrderedDict(
bspread_tensor_lookup = OrderedDict(
# One dense level of rank 1 over an element level.
"DVEC" => OrderedDict(
"subformat" => OrderedDict(
"level" => "dense",
"level" => OrderedDict(
"level_kind" => "dense",
"rank" => 1,
"subformat" => OrderedDict(
"level" => "element",
"level" => OrderedDict(
"level_kind" => "element",
)
)
),

# Two nested dense levels of rank 1 over an element level.
"DMAT" => OrderedDict(
"subformat" => OrderedDict(
"level" => "dense",
"level" => OrderedDict(
"level_kind" => "dense",
"rank" => 1,
"subformat" => OrderedDict(
"level" => "dense",
"level" => OrderedDict(
"level_kind" => "dense",
"rank" => 1,
"subformat" => OrderedDict(
"level" => "element",
"level" => OrderedDict(
"level_kind" => "element",
)
)
)
),

# Same description as "DMAT".
"DMATR" => OrderedDict(
"subformat" => OrderedDict(
"level" => "dense",
"level" => OrderedDict(
"level_kind" => "dense",
"rank" => 1,
"subformat" => OrderedDict(
"level" => "dense",
"level" => OrderedDict(
"level_kind" => "dense",
"rank" => 1,
"subformat" => OrderedDict(
"level" => "element",
"level" => OrderedDict(
"level_kind" => "element",
)
)
)
),

# Same levels as "DMAT", plus the [1, 0] dimension order.
"DMATC" => OrderedDict(
"swizzle" => [1, 0],
"subformat" => OrderedDict(
"level" => "dense",
"transpose" => [1, 0],
"level" => OrderedDict(
"level_kind" => "dense",
"rank" => 1,
"subformat" => OrderedDict(
"level" => "dense",
"level" => OrderedDict(
"level_kind" => "dense",
"rank" => 1,
"subformat" => OrderedDict(
"level" => "element",
"level" => OrderedDict(
"level_kind" => "element",
)
)
)
),

# One sparse level of rank 1 over an element level.
"CVEC" => OrderedDict(
"subformat" => OrderedDict(
"level" => "sparse",
"level" => OrderedDict(
"level_kind" => "sparse",
"rank" => 1,
"subformat" => OrderedDict(
"level" => "element",
"level" => OrderedDict(
"level_kind" => "element",
)
)
),

# Dense level of rank 1 over a sparse level of rank 1 over an element level.
"CSR" => OrderedDict(
"subformat" => OrderedDict(
"level" => "dense",
"level" => OrderedDict(
"level_kind" => "dense",
"rank" => 1,
"subformat" => OrderedDict(
"level" => "sparse",
"level" => OrderedDict(
"level_kind" => "sparse",
"rank" => 1,
"subformat" => OrderedDict(
"level" => "element",
"level" => OrderedDict(
"level_kind" => "element",
)
)
)
),

# Same levels as "CSR", plus the [1, 0] dimension order.
"CSC" => OrderedDict(
"swizzle" => [1, 0],
"subformat" => OrderedDict(
"level" => "dense",
"transpose" => [1, 0],
"level" => OrderedDict(
"level_kind" => "dense",
"rank" => 1,
"subformat" => OrderedDict(
"level" => "sparse",
"level" => OrderedDict(
"level_kind" => "sparse",
"rank" => 1,
"subformat" => OrderedDict(
"level" => "element",
"level" => OrderedDict(
"level_kind" => "element",
)
)
)
),

# Two nested sparse levels of rank 1 over an element level.
"DCSR" => OrderedDict(
"subformat" => OrderedDict(
"level" => "sparse",
"level" => OrderedDict(
"level_kind" => "sparse",
"rank" => 1,
"subformat" => OrderedDict(
"level" => "sparse",
"level" => OrderedDict(
"level_kind" => "sparse",
"rank" => 1,
"subformat" => OrderedDict(
"level" => "element",
"level" => OrderedDict(
"level_kind" => "element",
)
)
)
),

# Same levels as "DCSR", plus the [1, 0] dimension order.
"DCSC" => OrderedDict(
"swizzle" => [1, 0],
"subformat" => OrderedDict(
"level" => "sparse",
"transpose" => [1, 0],
"level" => OrderedDict(
"level_kind" => "sparse",
"rank" => 1,
"subformat" => OrderedDict(
"level" => "sparse",
"level" => OrderedDict(
"level_kind" => "sparse",
"rank" => 1,
"subformat" => OrderedDict(
"level" => "element",
"level" => OrderedDict(
"level_kind" => "element",
)
)
)
),

# A single sparse level of rank 2 over an element level.
"COO" => OrderedDict(
"subformat" => OrderedDict(
"level" => "sparse",
"level" => OrderedDict(
"level_kind" => "sparse",
"rank" => 2,
"subformat" => OrderedDict(
"level" => "element",
"level" => OrderedDict(
"level_kind" => "element",
)
)
),

# Same description as "COO".
"COOR" => OrderedDict(
"subformat" => OrderedDict(
"level" => "sparse",
"level" => OrderedDict(
"level_kind" => "sparse",
"rank" => 2,
"subformat" => OrderedDict(
"level" => "element",
"level" => OrderedDict(
"level_kind" => "element",
)
)
),

# Same levels as "COO", plus the [1, 0] dimension order.
"COOC" => OrderedDict(
"swizzle" => [1, 0],
"subformat" => OrderedDict(
"level" => "sparse",
"transpose" => [1, 0],
"level" => OrderedDict(
"level_kind" => "sparse",
"rank" => 2,
"subformat" => OrderedDict(
"level" => "element",
"level" => OrderedDict(
"level_kind" => "element",
)
)
),
)

# Inverse of the read-side table: maps a level description back to its format
# name, so the writer can emit a short name when the description matches.
# NOTE(review): "DMAT"/"DMATR" and "COO"/"COOR" have equal descriptions in the
# read-side table, so presumably only the later name of each pair survives the
# inversion — confirm that "DMATR"/"COOR" are the names the writer should emit.
# The two lines below are the old and new spelling of the same definition.
bspwrite_format_lookup = OrderedDict(v => k for (k, v) in bspread_format_lookup)
bspwrite_format_lookup = OrderedDict(v => k for (k, v) in bspread_tensor_lookup)

#indices_zero_to_one(vec::Vector{Ti}) where {Ti} = PlusOneVector(vec)
# Shift every entry of `vec` up by one, using the vector's own element type for
# the increment so the result keeps that element type.
function indices_zero_to_one(vec::Vector)
    increment = one(eltype(vec))
    return vec .+ increment
end
Expand Down Expand Up @@ -271,20 +273,23 @@ bspwrite_tensor(io, fbr::Tensor, attrs = OrderedDict()) =

# Write a swizzled tensor to `io` in binsparse form.
# Builds the descriptor, records a non-default dimension order, lets
# bspwrite_level fill in the level description (and write the level data),
# swaps in a short format name when the description matches a known format,
# and finally writes the descriptor as JSON under the "binsparse" key.
# `attrs` is stored verbatim under "attrs".
# NOTE(review): this view appears to interleave superseded lines ("format" /
# "subformat" / "swizzle", the literal "0.1") with their replacements
# ("tensor" / "level" / "transpose", BINSPARSE_VERSION).
function bspwrite_tensor(io, arr::SwizzleArray{dims, <:Tensor}, attrs = OrderedDict()) where {dims}
desc = OrderedDict(
"format" => OrderedDict{Any, Any}(
"subformat" => OrderedDict(),
"tensor" => OrderedDict{Any, Any}(
"level" => OrderedDict(),
),
"fill" => true,
"shape" => map(Int, size(arr)),
"data_types" => OrderedDict(),
"version" => "0.1",
"version" => "$BINSPARSE_VERSION",
"number_of_stored_values" => countstored(arr),
"attrs" => attrs,
)
# Only record the dimension order when the reversed `dims` is not already
# sorted; the stored order is the reversed `dims`, shifted to zero-based.
if !issorted(reverse(collect(dims)))
desc["format"]["swizzle"] = reverse(collect(dims)) .- 1
desc["tensor"]["transpose"] = reverse(collect(dims)) .- 1
end
# Fill in the (initially empty) level description from the tensor's levels.
bspwrite_level(io, desc, desc["tensor"]["level"], arr.body.lvl)
# "format" is only set when the finished description is a key of the name
# table; otherwise the descriptor carries just the "tensor" description.
if haskey(bspwrite_format_lookup, desc["tensor"])
desc["format"] = bspwrite_format_lookup[desc["tensor"]]
end
bspwrite_level(io, desc, desc["format"]["subformat"], arr.body.lvl)
desc["format"] = get(bspwrite_format_lookup, desc["format"], desc["format"])
# The second argument to `json` is passed as 4 (presumably the indent width).
bspwrite_header(io, json(Dict("binsparse" => desc), 4))
end

Expand All @@ -307,29 +312,29 @@ function bspread_header end

# Read a binsparse tensor from `f`.
# Fetches the "binsparse" descriptor from the header, checks its version,
# resolves the level description, rebuilds the levels into a Tensor, and wraps
# it in a swizzle when the stored dimension order is not the default.
# Throws ArgumentError when the descriptor has a "structure" key.
# NOTE(review): this view appears to interleave superseded lines ("swizzle",
# "subformat", the literal "0.1", bspread_format_lookup) with their
# replacements ("transpose", "level", BINSPARSE_VERSION, bspread_tensor_lookup).
function bspread(f)
desc = bspread_header(f)["binsparse"]
# NOTE(review): the version check validates file contents with @assert, so a
# mismatch surfaces as an AssertionError; consider an explicit throw.
@assert desc["version"] == "0.1"
fmt = OrderedDict{Any, Any}(get(bspread_format_lookup, desc["format"], desc["format"]))
if !haskey(fmt, "swizzle")
fmt["swizzle"] = collect(0:length(desc["shape"]) - 1)
@assert desc["version"] == "$BINSPARSE_VERSION"
# A named format is looked up in the table; the closure supplies
# desc["tensor"] when desc["format"] is not one of the table's keys.
# NOTE(review): desc["format"] is indexed before `get` runs, so a descriptor
# without a "format" key would presumably raise a KeyError instead of falling
# back — and bspwrite_tensor only sets "format" when the description matches a
# named format. Confirm custom (unnamed) formats round-trip.
fmt = OrderedDict{Any, Any}(get(() -> desc["tensor"], bspread_tensor_lookup, desc["format"]))
# With no explicit order, default to the zero-based identity 0:N-1.
if !haskey(fmt, "transpose")
fmt["transpose"] = collect(0:length(desc["shape"]) - 1)
end
if !issorted(reverse(fmt["swizzle"]))
sigma = sortperm(reverse(fmt["swizzle"] .+ 1))
# For a non-default order, permute the stored shape by the sort permutation of
# the reversed, one-based order before the levels are rebuilt.
if !issorted(reverse(fmt["transpose"]))
sigma = sortperm(reverse(fmt["transpose"] .+ 1))
desc["shape"] = desc["shape"][sigma]
end
fbr = Tensor(bspread_level(f, desc, fmt["subformat"]))
if !issorted(reverse(fmt["swizzle"]))
fbr = swizzle(fbr, reverse(fmt["swizzle"] .+ 1)...)
fbr = Tensor(bspread_level(f, desc, fmt["level"]))
# Re-apply the stored order as a swizzle over the rebuilt tensor.
if !issorted(reverse(fmt["transpose"]))
fbr = swizzle(fbr, reverse(fmt["transpose"] .+ 1)...)
end
# This check runs only after the level data has already been read.
if haskey(desc, "structure")
throw(ArgumentError("binsparse structure field currently unsupported"))
end
fbr
end

# Entry point for reading one level: turns the level's kind string into a
# Val(Symbol(...)) so the matching four-argument bspread_level method is chosen
# by dispatch. The two lines are the old ("level") and new ("level_kind")
# spelling of the key that holds the kind string.
bspread_level(f, desc, fmt) = bspread_level(f, desc, fmt, Val(Symbol(fmt["level"])))
bspread_level(f, desc, fmt) = bspread_level(f, desc, fmt, Val(Symbol(fmt["level_kind"])))

# Write an ElementLevel: tag the level description `fmt` as "element", then
# hand the level's `val` to bspwrite_data under "values" and the type parameter
# D, wrapped in a one-element vector, under "fill_value".
# The two tag assignments are the old ("level") and new ("level_kind") key.
function bspwrite_level(f, desc, fmt, lvl::ElementLevel{D}) where {D}
fmt["level"] = "element"
fmt["level_kind"] = "element"
bspwrite_data(f, desc, "values", lvl.val)
bspwrite_data(f, desc, "fill_value", [D])
end
Expand All @@ -344,13 +349,13 @@ function bspread_level(f, desc, fmt, ::Val{:element})
end

# Write a DenseLevel: tag the level description `fmt` as "dense" with rank 1,
# attach a fresh empty description for the child level, and recurse into
# `lvl.lvl` to fill it. No data arrays are written for the dense level itself.
# Old/new line pairs: the kind tag moved from "level" to "level_kind", and the
# child description moved from "subformat" to "level".
function bspwrite_level(f, desc, fmt, lvl::DenseLevel{D}) where {D}
fmt["level"] = "dense"
fmt["level_kind"] = "dense"
fmt["rank"] = 1
fmt["subformat"] = OrderedDict()
bspwrite_level(f, desc, fmt["subformat"], lvl.lvl)
fmt["level"] = OrderedDict()
bspwrite_level(f, desc, fmt["level"], lvl.lvl)
end
function bspread_level(f, desc, fmt, ::Val{:dense})
lvl = bspread_level(f, desc, fmt["subformat"])
lvl = bspread_level(f, desc, fmt["level"])
R = fmt["rank"]
for r = 1:R
n = level_ndims(typeof(lvl))
Expand All @@ -361,19 +366,19 @@ function bspread_level(f, desc, fmt, ::Val{:dense})
end

# Write a SparseListLevel: tag the level description `fmt` as "sparse" with
# rank 1, write the level's pointer and index arrays, then recurse into the
# child level.
# Old/new line pairs: the kind tag moved from "level" to "level_kind", and the
# child description moved from "subformat" to "level".
function bspwrite_level(f, desc, fmt, lvl::SparseListLevel)
fmt["level"] = "sparse"
fmt["level_kind"] = "sparse"
fmt["rank"] = 1
# N - n is used as the numeric suffix of the data keys written below.
n = level_ndims(typeof(lvl))
N = length(desc["shape"])
# Pointers are skipped when N - n is zero (presumably the outermost level,
# which has no parent positions to point from — confirm).
if N - n > 0
bspwrite_data(f, desc, "pointers_to_$(N - n)", indices_one_to_zero(lvl.ptr))
end
# Both arrays go through indices_one_to_zero before being written.
bspwrite_data(f, desc, "indices_$(N - n)", indices_one_to_zero(lvl.idx))
fmt["subformat"] = OrderedDict()
bspwrite_level(f, desc, fmt["subformat"], lvl.lvl)
fmt["level"] = OrderedDict()
bspwrite_level(f, desc, fmt["level"], lvl.lvl)
end
function bspwrite_level(f, desc, fmt, lvl::SparseCOOLevel{R}) where {R}
fmt["level"] = "sparse"
fmt["level_kind"] = "sparse"
fmt["rank"] = R
n = level_ndims(typeof(lvl))
N = length(desc["shape"])
Expand All @@ -383,12 +388,12 @@ function bspwrite_level(f, desc, fmt, lvl::SparseCOOLevel{R}) where {R}
for r = 1:R
bspwrite_data(f, desc, "indices_$(N - n + r - 1)", indices_one_to_zero(lvl.tbl[r]))
end
fmt["subformat"] = OrderedDict()
bspwrite_level(f, desc, fmt["subformat"], lvl.lvl)
fmt["level"] = OrderedDict()
bspwrite_level(f, desc, fmt["level"], lvl.lvl)
end
function bspread_level(f, desc, fmt, ::Val{:sparse})
R = fmt["rank"]
lvl = bspread_level(f, desc, fmt["subformat"])
lvl = bspread_level(f, desc, fmt["level"])
n = level_ndims(typeof(lvl)) + R
N = length(desc["shape"])
tbl = (map(1:R) do r
Expand Down
2 changes: 2 additions & 0 deletions src/tensors/combinators/swizzle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ default(arr::SwizzleArray) = default(typeof(arr))
# The default of a SwizzleArray type is the default of its wrapped body type.
function default(::Type{SwizzleArray{dims, Body}}) where {dims, Body}
    return default(Body)
end
# Build a similar array by calling `similar` on the wrapped body and re-wrapping
# the result with the same `dims` parameter.
function Base.similar(arr::SwizzleArray{dims}) where {dims}
    fresh_body = similar(arr.body)
    return SwizzleArray{dims}(fresh_body)
end

# The stored-value count of a swizzled array is that of its wrapped body.
function countstored(arr::SwizzleArray)
    return countstored(arr.body)
end

# Size of the swizzled view: the body's extents picked out in `dims` order.
function Base.size(arr::SwizzleArray{dims}) where {dims}
    body_extents = size(arr.body)
    return map(axis -> body_extents[axis], dims)
end

# Two-argument show forwards to the text/plain three-argument method.
function Base.show(io::IO, ex::SwizzleArray)
    return Base.show(io, MIME"text/plain"(), ex)
end
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
MatrixDepot = "b51810bb-c9f3-55da-ae3c-350fc1fbce05"
MatrixMarket = "4d4711f2-db25-561a-b6b3-d35e7d4047d3"
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Expand Down
Loading

0 comments on commit e6622c4

Please sign in to comment.