Skip to content

Commit

Permalink
Don't treat Vector{UInt8} as Arrow Binary type (#439)
Browse files Browse the repository at this point in the history
Fixes #411. Alternative to #419.

This PR should be compatible with or without the ArrowTypes changes. I
think it's fine to do compat things in Arrow like this as long as they
don't get out of hand and we can eventually remove them as we bump
required ArrowTypes versions and such.

The PR consists of not treating `Vector{UInt8}` as the Arrow Binary
type, which is meant for "binary string"s. Julia has a pretty good match
for that in `Base.CodeUnits`, so instead, we use that to write Binary
and `Vector{UInt8}` is treated as a regular List of Primitive UInt8
type.
  • Loading branch information
quinnj authored May 19, 2023
1 parent b2a832e commit 899ecb0
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 16 deletions.
4 changes: 4 additions & 0 deletions src/ArrowTypes/src/ArrowTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,12 @@ isstringtype(::ListKind{stringtype}) where {stringtype} = stringtype
isstringtype(::Type{ListKind{stringtype}}) where {stringtype} = stringtype

ArrowKind(::Type{<:AbstractString}) = ListKind{true}()
# Treate Base.CodeUnits as Binary arrow type
ArrowKind(::Type{<:Base.CodeUnits}) = ListKind{true}()

fromarrow(::Type{T}, ptr::Ptr{UInt8}, len::Int) where {T} = fromarrow(T, unsafe_string(ptr, len))
fromarrow(::Type{T}, x) where {T <: Base.CodeUnits} = Base.CodeUnits(x)
fromarrow(::Type{Union{Missing, Base.CodeUnits}}, x) = x === missing ? missing : Base.CodeUnits(x)

ArrowType(::Type{Symbol}) = String
toarrow(x::Symbol) = String(x)
Expand Down
3 changes: 3 additions & 0 deletions src/ArrowTypes/test/tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,12 @@ end
@test !ArrowTypes.isstringtype(ArrowTypes.ListKind())
@test !ArrowTypes.isstringtype(typeof(ArrowTypes.ListKind()))
@test ArrowTypes.ArrowKind(String) == ArrowTypes.ListKind{true}()
@test ArrowTypes.ArrowKind(Base.CodeUnits) == ArrowTypes.ListKind{true}()

hey = collect(b"hey")
@test ArrowTypes.fromarrow(String, pointer(hey), 3) == "hey"
@test ArrowTypes.fromarrow(Base.CodeUnits, pointer(hey), 3) == b"hey"
@test ArrowTypes.fromarrow(Union{Base.CodeUnits, Missing}, pointer(hey), 3) == b"hey"

@test ArrowTypes.ArrowType(Symbol) == String
@test ArrowTypes.toarrow(:hey) == "hey"
Expand Down
44 changes: 33 additions & 11 deletions src/arraytypes/list.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,20 @@ Base.size(l::List) = (l.ℓ,)
@inbounds lo, hi = l.offsets[i]
S = Base.nonmissingtype(T)
K = ArrowTypes.ArrowKind(ArrowTypes.ArrowType(S))
if ArrowTypes.isstringtype(K)
# special-case Base.CodeUnits for ArrowTypes compat
if ArrowTypes.isstringtype(K) || S <: Base.CodeUnits
if S !== T
return l.validity[i] ? ArrowTypes.fromarrow(T, pointer(l.data, lo), hi - lo + 1) : missing
if S <: Base.CodeUnits
return l.validity[i] ? Base.CodeUnits(unsafe_string(pointer(l.data, lo), hi - lo + 1)) : missing
else
return l.validity[i] ? ArrowTypes.fromarrow(T, pointer(l.data, lo), hi - lo + 1) : missing
end
else
return ArrowTypes.fromarrow(T, pointer(l.data, lo), hi - lo + 1)
if S <: Base.CodeUnits
return Base.CodeUnits(unsafe_string(pointer(l.data, lo), hi - lo + 1))
else
return ArrowTypes.fromarrow(T, pointer(l.data, lo), hi - lo + 1)
end
end
elseif S !== T
return l.validity[i] ? ArrowTypes.fromarrow(T, view(l.data, lo:hi)) : missing
Expand All @@ -66,6 +75,12 @@ end

# end

# internal interface definitions to be able to treat AbstractString/CodeUnits similarly
_ncodeunits(x::AbstractString) = ncodeunits(x)
_codeunits(x::AbstractString) = codeunits(x)
_ncodeunits(x::Base.CodeUnits) = length(x)
_codeunits(x::Base.CodeUnits) = x

# an AbstractVector version of Iterators.flatten
# code based on SentinelArrays.ChainedVector
struct ToList{T, stringtype, A, I} <: AbstractVector{T}
Expand All @@ -74,14 +89,21 @@ struct ToList{T, stringtype, A, I} <: AbstractVector{T}
end

origtype(::ToList{T, S, A, I}) where {T, S, A, I} = A
liststringtype(::Type{ToList{T, S, A, I}}) where {T, S, A, I} = S
function liststringtype(::List{T, O, A}) where {T, O, A}
ST = Base.nonmissingtype(T)
K = ArrowTypes.ArrowKind(ST)
return liststringtype(A) || ArrowTypes.isstringtype(K) || ST <: Base.CodeUnits # add the CodeUnits check for ArrowTypes compat for now
end
liststringtype(T) = false

function ToList(input; largelists::Bool=false)
AT = eltype(input)
ST = Base.nonmissingtype(AT)
K = ArrowTypes.ArrowKind(ST)
stringtype = ArrowTypes.isstringtype(K)
stringtype = ArrowTypes.isstringtype(K) || ST <: Base.CodeUnits # add the CodeUnits check for ArrowTypes compat for now
T = stringtype ? UInt8 : eltype(ST)
len = stringtype ? ncodeunits : length
len = stringtype ? _ncodeunits : length
data = AT[]
I = largelists ? Int64 : Int32
inds = I[0]
Expand Down Expand Up @@ -122,15 +144,15 @@ Base.@propagate_inbounds function Base.getindex(A::ToList{T, stringtype}, i::Int
@boundscheck checkbounds(A, i)
chunk, ix = index(A, i)
@inbounds x = A.data[chunk]
return @inbounds stringtype ? codeunits(x)[ix] : x[ix]
return @inbounds stringtype ? _codeunits(x)[ix] : x[ix]
end

Base.@propagate_inbounds function Base.setindex!(A::ToList{T, stringtype}, v, i::Integer) where {T, stringtype}
@boundscheck checkbounds(A, i)
chunk, ix = index(A, i)
@inbounds x = A.data[chunk]
if stringtype
codeunits(x)[ix] = v
_codeunits(x)[ix] = v
else
x[ix] = v
end
Expand All @@ -149,7 +171,7 @@ end
chunk_len = A.inds[chunk]
end
val = A.data[chunk - 1]
x = stringtype ? codeunits(val)[1] : val[1]
x = stringtype ? _codeunits(val)[1] : val[1]
# find next valid index
i += 1
if i > chunk_len
Expand All @@ -168,7 +190,7 @@ end
@inline function Base.iterate(A::ToList{T, stringtype}, (i, chunk, chunk_i, chunk_len, len)) where {T, stringtype}
i > len && return nothing
@inbounds val = A.data[chunk - 1]
@inbounds x = stringtype ? codeunits(val)[chunk_i] : val[chunk_i]
@inbounds x = stringtype ? _codeunits(val)[chunk_i] : val[chunk_i]
i += 1
if i > chunk_len
chunk_i = 1
Expand All @@ -191,7 +213,7 @@ function arrowvector(::ListKind, x, i, nl, fi, de, ded, meta; largelists::Bool=f
validity = ValidityBitmap(x)
flat = ToList(x; largelists=largelists)
offsets = Offsets(UInt8[], flat.inds)
if eltype(flat) == UInt8 # binary or utf8string
if liststringtype(typeof(flat)) && eltype(flat) == UInt8 # binary or utf8string
data = flat
T = origtype(flat)
else
Expand All @@ -208,7 +230,7 @@ function compress(Z::Meta.CompressionType, comp, x::List{T, O, A}) where {T, O,
offsets = compress(Z, comp, x.offsets.offsets)
buffers = [validity, offsets]
children = Compressed[]
if eltype(A) == UInt8
if liststringtype(x)
push!(buffers, compress(Z, comp, x.data))
else
push!(children, compress(Z, comp, x.data))
Expand Down
4 changes: 2 additions & 2 deletions src/arraytypes/map.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ function makenodesbuffers!(col::Union{Map{T, O, A}, List{T, O, A}}, fieldnodes,
push!(fieldbuffers, Buffer(bufferoffset, blen))
@debugv 1 "made field buffer: bufferidx = $(length(fieldbuffers)), offset = $(fieldbuffers[end].offset), len = $(fieldbuffers[end].length), padded = $(padding(fieldbuffers[end].length, alignment))"
bufferoffset += padding(blen, alignment)
if eltype(A) == UInt8
if liststringtype(col)
blen = length(col.data)
push!(fieldbuffers, Buffer(bufferoffset, blen))
@debugv 1 "made field buffer: bufferidx = $(length(fieldbuffers)), offset = $(fieldbuffers[end].offset), len = $(fieldbuffers[end].length), padded = $(padding(fieldbuffers[end].length, alignment))"
Expand All @@ -110,7 +110,7 @@ function writebuffer(io, col::Union{Map{T, O, A}, List{T, O, A}}, alignment) whe
@debugv 1 "writing array: col = $(typeof(col.offsets.offsets)), n = $n, padded = $(padding(n, alignment))"
writezeros(io, paddinglength(n, alignment))
# write values array
if eltype(A) == UInt8
if liststringtype(col)
n = writearray(io, UInt8, col.data)
@debugv 1 "writing array: col = $(typeof(col.data)), n = $n, padded = $(padding(n, alignment))"
writezeros(io, paddinglength(n, alignment))
Expand Down
6 changes: 3 additions & 3 deletions src/eltypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ juliaeltype(f::Meta.Field, b::Union{Meta.Utf8, Meta.LargeUtf8}, convert) = Strin
datasizeof(x) = sizeof(x)
datasizeof(x::AbstractVector) = sum(datasizeof, x)

juliaeltype(f::Meta.Field, b::Union{Meta.Binary, Meta.LargeBinary}, convert) = Vector{UInt8}
juliaeltype(f::Meta.Field, b::Union{Meta.Binary, Meta.LargeBinary}, convert) = Base.CodeUnits

juliaeltype(f::Meta.Field, x::Meta.FixedSizeBinary, convert) = NTuple{Int(x.byteWidth), UInt8}

Expand Down Expand Up @@ -393,7 +393,7 @@ end

# arrowtype will call fieldoffset recursively for children
function arrowtype(b, x::List{T, O, A}) where {T, O, A}
if eltype(A) == UInt8
if liststringtype(x)
if T <: AbstractString || T <: Union{AbstractString, Missing}
if O == Int32
Meta.utf8Start(b)
Expand All @@ -402,7 +402,7 @@ function arrowtype(b, x::List{T, O, A}) where {T, O, A}
Meta.largUtf8Start(b)
return Meta.LargeUtf8, Meta.largUtf8End(b), nothing
end
else # if Vector{UInt8}
else # if Base.CodeUnits
if O == Int32
Meta.binaryStart(b)
return Meta.Binary, Meta.binaryEnd(b), nothing
Expand Down
36 changes: 36 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,42 @@ for (col1, col2) in zip(Tables.columns(df), Tables.columns(df_load))
@test col1 == col2
end

@testset "# 411" begin
# Vector{UInt8} are written as List{UInt8} in Arrow
# Base.CodeUnits are written as Binary
t = (
a=[[0x00, 0x01], UInt8[], [0x03]],
am=[[0x00, 0x01], [0x03], missing],
b=[b"01", b"", b"3"],
bm=[b"01", b"3", missing],
c=["a", "b", "c"],
cm=["a", "c", missing]
)
buf = Arrow.tobuffer(t)
tt = Arrow.Table(buf)
@test t.a == tt.a
@test isequal(t.am, tt.am)
@test t.b == tt.b
@test isequal(t.bm, tt.bm)
@test t.c == tt.c
@test isequal(t.cm, tt.cm)
@test Arrow.schema(tt)[].fields[1].type isa Arrow.Flatbuf.List
@test Arrow.schema(tt)[].fields[3].type isa Arrow.Flatbuf.Binary
pos = position(buf)
Arrow.append(buf, tt)
seekstart(buf)
buf1 = read(buf, pos)
buf2 = read(buf)
t1 = Arrow.Table(buf1)
t2 = Arrow.Table(buf2)
@test isequal(t1.a, t2.a)
@test isequal(t1.am, t2.am)
@test isequal(t1.b, t2.b)
@test isequal(t1.bm, t2.bm)
@test isequal(t1.c, t2.c)
@test isequal(t1.cm, t2.cm)

end

end # @testset "misc"

Expand Down

0 comments on commit 899ecb0

Please sign in to comment.