Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Containers] add support for Base.getindex(::Container; kwargs...) #3237

Merged
merged 29 commits into from
Mar 29, 2023
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions docs/src/manual/containers.md
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,22 @@ julia> DataFrames.DataFrame(table)
4 │ 2 B (2, :B)
```

### Keyword indexing

If all axes are named, you can use keyword indexing:

```jldoctest containers_dense
julia> x[i = 2, j = :A]
(2, :A)

julia> x[i = :, j = :B]
1-dimensional DenseAxisArray{Tuple{Int64, Symbol},1,...} with index sets:
Dimension 1, Base.OneTo(2)
And data, a 2-element Vector{Tuple{Int64, Symbol}}:
(1, :B)
(2, :B)
```

## SparseAxisArray

A [`Containers.SparseAxisArray`](@ref) is created when the index sets are
Expand Down Expand Up @@ -352,6 +368,20 @@ julia> DataFrames.DataFrame(table)
4 │ 3 A (3, :A)
```

### Keyword indexing

If all axes are named, you can use keyword indexing:

```jldoctest containers_sparse
julia> x[i = 2, j = :A]
(2, :A)

julia> x[i = :, j = :B]
JuMP.Containers.SparseAxisArray{Tuple{Int64, Symbol}, 1, Tuple{Int64}} with 2 entries:
[2] = (2, :B)
[3] = (3, :B)
```

## Forcing the container type

Pass `container = T` to use `T` as the container. For example:
Expand Down
1 change: 1 addition & 0 deletions docs/src/reference/containers.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Containers.DenseAxisArray
Containers.SparseAxisArray
Containers.container
Containers.rowtable
Containers.AutoContainerType
Containers.default_container
Containers.@container
Containers.VectorizedProductIterator
Expand Down
26 changes: 26 additions & 0 deletions docs/src/tutorials/getting_started/getting_started_with_JuMP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,14 @@ end #hide

@variable(model, a[1:2, 1:2])

# Index elements in `a` as follows:

a[1, 1]

#-

a[2, :]

# Create an n-dimensional variable $x \in {R}^n$ with bounds $l \le x \le u$
# ($l, u \in {R}^n$) as follows:

Expand All @@ -352,6 +360,16 @@ u = [10, 11, 12, 13, 14, 15, 16, 17, 18, 19];

@variable(model, w[1:5, ["red", "blue"]] <= 1)

# Index elements in a `DenseAxisArray` as follows:

z[2, 1]

#-

w[2:3, ["red", "blue"]]

# See [Forcing the container type](@ref variable_forcing) for more details.

# #### SparseAxisArrays

# `SparseAxisArrays` are created when the indices do not form a Cartesian product.
Expand All @@ -366,6 +384,14 @@ u = [10, 11, 12, 13, 14, 15, 16, 17, 18, 19];

@variable(model, v[i = 1:9; mod(i, 3) == 0])

# Index elements in a `DenseAxisArray` as follows:

u[1, 2]

#-

v[[3, 6]]

# ### Integrality

# JuMP can create binary and integer variables. Binary variables are constrained
Expand Down
158 changes: 128 additions & 30 deletions src/Containers/DenseAxisArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ struct DenseAxisArray{T,N,Ax,L<:NTuple{N,_AxisLookup}} <: AbstractArray{T,N}
data::Array{T,N}
axes::Ax
lookup::L
names::NTuple{N,Symbol}
end

function Base.Array{T,N}(x::DenseAxisArray) where {T,N}
Expand Down Expand Up @@ -200,15 +201,20 @@ julia> array[:b, 3]
4
```
"""
function DenseAxisArray(data::Array{T,N}, axs...) where {T,N}
@assert length(axs) == N
new_axes = _abstract_vector.(axs) # Force all axes to be AbstractVector!
return DenseAxisArray(data, new_axes, build_lookup.(new_axes))
function DenseAxisArray(
data::Array{T,N},
axes...;
names::Union{Nothing,NTuple{N,Symbol}} = nothing,
) where {T,N}
@assert length(axes) == N
new_axes = _abstract_vector.(axes) # Force all axes to be AbstractVector!
names = something(names, ntuple(n -> Symbol("#$n"), N))
return DenseAxisArray(data, new_axes, build_lookup.(new_axes), names)
end

# A converter for different array types.
function DenseAxisArray(data::AbstractArray, axes...)
return DenseAxisArray(collect(data), axes...)
function DenseAxisArray(data::AbstractArray, axes...; kwargs...)
return DenseAxisArray(collect(data), axes...; kwargs...)
end

"""
Expand Down Expand Up @@ -245,12 +251,17 @@ And data, a 2×2 Array{Float64,2}:
1.0 1.0
```
"""
function DenseAxisArray{T}(::UndefInitializer, axs...) where {T}
return construct_undef_array(T, axs)
function DenseAxisArray{T}(::UndefInitializer, args...; kwargs...) where {T}
return construct_undef_array(T, args; kwargs...)
end

function construct_undef_array(::Type{T}, axs::Tuple{Vararg{Any,N}}) where {T,N}
return DenseAxisArray(Array{T,N}(undef, length.(axs)...), axs...)
function construct_undef_array(
::Type{T},
args::Tuple{Vararg{Any,N}};
kwargs...,
) where {T,N}
data = Array{T,N}(undef, length.(args)...)
return DenseAxisArray(data, args...; kwargs...)
end

Base.isempty(A::DenseAxisArray) = isempty(A.data)
Expand Down Expand Up @@ -344,19 +355,51 @@ end
_is_range(::Any) = false
_is_range(::Union{Vector{Int},Colon}) = true

function Base.getindex(A::DenseAxisArray{T,N}, idx...) where {T,N}
new_indices = Base.to_index(A, idx)
function _kwargs_to_args(A::DenseAxisArray{T,N}; kwargs...) where {T,N}
return ntuple(N) do i
kw = keys(kwargs)[i]
if A.names[i] != kw
error(
"Invalid index $kw in position $i. When using keyword " *
"indexing, the indices must match the exact name and order " *
"used when creating the container.",
)
end
return kwargs[i]
end
end

function Base.getindex(A::DenseAxisArray{T,N}, args...; kwargs...) where {T,N}
if !isempty(kwargs)
if !isempty(args)
error("Cannot index with mix of positional and keyword arguments")
end
return getindex(A, _kwargs_to_args(A; kwargs...)...)
end
new_indices = Base.to_index(A, args)
if !any(_is_range, new_indices)
return A.data[new_indices...]::T
end
new_axes = _getindex_recurse(A.axes, new_indices, _is_range)
return DenseAxisArray(A.data[new_indices...], new_axes...)
names = A.names[findall(_is_range, new_indices)]
return DenseAxisArray(A.data[new_indices...], new_axes...; names = names)
end

Base.getindex(A::DenseAxisArray, idx::CartesianIndex) = A.data[idx]

function Base.setindex!(A::DenseAxisArray{T,N}, v, idx...) where {T,N}
return A.data[Base.to_index(A, idx)...] = v
function Base.setindex!(
A::DenseAxisArray{T,N},
v,
args...;
kwargs...,
) where {T,N}
if !isempty(kwargs)
if !isempty(args)
error("Cannot index with mix of positional and keyword arguments")
end
return setindex!(A, v, _kwargs_to_args(A; kwargs...)...)
end
return A.data[Base.to_index(A, args)...] = v
end

function Base.setindex!(
Expand Down Expand Up @@ -466,26 +509,37 @@ function _broadcast_axes_check(x::NTuple{N}) where {N}
return axes
end

_broadcast_axes(x::Tuple) = _broadcast_axes(first(x), Base.tail(x))
_broadcast_axes(::Tuple{}) = ()
_broadcast_axes(::Any, tail) = _broadcast_axes(tail)
function _broadcast_axes(x::DenseAxisArray, tail)
return ((x.axes, x.lookup), _broadcast_axes(tail)...)
_broadcast_args(f, x::Tuple) = _broadcast_args(f, first(x), Base.tail(x))

_broadcast_args(f, ::Tuple{}) = ()

_broadcast_args(f::Val{:axes}, x::Any, tail) = _broadcast_args(f, tail)

function _broadcast_args(f::Val{:axes}, x::DenseAxisArray, tail)
return ((x.axes, x.lookup), _broadcast_args(f, tail)...)
end

_broadcast_args(f::Val{:data}, x::Any, tail) = (x, _broadcast_args(f, tail)...)

function _broadcast_args(f::Val{:data}, x::DenseAxisArray, tail)
return (x.data, _broadcast_args(f, tail)...)
end

_broadcast_args(x::Tuple) = _broadcast_args(first(x), Base.tail(x))
_broadcast_args(::Tuple{}) = ()
_broadcast_args(x::Any, tail) = (x, _broadcast_args(tail)...)
_broadcast_args(x::DenseAxisArray, tail) = (x.data, _broadcast_args(tail)...)
_broadcast_args(f::Val{:names}, x::Any, tail) = _broadcast_args(f, tail)

function _broadcast_args(f::Val{:names}, x::DenseAxisArray, tail)
return (x.names, _broadcast_args(f, tail)...)
end

function Base.Broadcast.broadcasted(
::Broadcast.ArrayStyle{DenseAxisArray},
f,
args...,
)
axes_lookup = _broadcast_axes_check(_broadcast_axes(args))
new_args = _broadcast_args(args)
return DenseAxisArray(broadcast(f, new_args...), axes_lookup...)
axes, lookup = _broadcast_axes_check(_broadcast_args(Val(:axes), args))
new_args = _broadcast_args(Val(:data), args)
names = _broadcast_args(Val(:names), args)
return DenseAxisArray(broadcast(f, new_args...), axes, lookup, first(names))
end

########
Expand Down Expand Up @@ -650,12 +704,18 @@ end

Base.axes(x::DenseAxisArrayView) = _type_stable_axes(x.axes)

_is_subaxis(key::K, axis::AbstractVector{K}) where {K} = key in axis

function _is_subaxis(key::AbstractVector{K}, axis::AbstractVector{K}) where {K}
return all(k -> k in axis, key)
end

function _type_stable_args(axis::AbstractVector, ::Colon, axes, args)
return (axis, _type_stable_args(axes, args)...)
end

function _type_stable_args(axis::AbstractVector, arg, axes, args)
if !(arg in axis)
if !_is_subaxis(arg, axis)
throw(KeyError(arg))
end
return (arg, _type_stable_args(axes, args)...)
Expand All @@ -676,7 +736,34 @@ end

_type_stable_args(axes::Tuple, ::Tuple{}) = axes

function Base.getindex(x::DenseAxisArrayView, args...)
function _fixed_indices(view_axes::Tuple, axes::Tuple)
return filter(ntuple(i -> i, length(view_axes))) do i
return !(typeof(view_axes[i]) <: eltype(axes[i]))
end
end

function _kwargs_to_args(A::DenseAxisArrayView{T,N}; kwargs...) where {T,N}
non_default_indices = _fixed_indices(A.axes, A.data.axes)
return ntuple(N) do i
kw = keys(kwargs)[i]
if A.data.names[non_default_indices[i]] != kw
error(
"Invalid index $kw in position $i. When using keyword " *
"indexing, the indices must match the exact name and order " *
"used when creating the container.",
)
end
return kwargs[i]
end
end

function Base.getindex(x::DenseAxisArrayView, args...; kwargs...)
if !isempty(kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getindex is performance sensitive. Is there any impact from this extra !isempty(kwargs) in the case of all positional arguments?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope. It gets compiled away:

shell> git status
On branch od/container-kwarg-getindex
Your branch is up to date with 'origin/od/container-kwarg-getindex'.

nothing to commit, working tree clean

julia> using JuMP

julia> using BenchmarkTools

julia> model = Model()
A JuMP Model
Feasibility problem with:
Variables: 0
Model mode: AUTOMATIC
CachingOptimizer state: NO_OPTIMIZER
Solver name: No optimizer attached.

julia> @variable(model, x[i=2:3, j=4:5])
2-dimensional DenseAxisArray{VariableRef,2,...} with index sets:
    Dimension 1, 2:3
    Dimension 2, 4:5
And data, a 2×2 Matrix{VariableRef}:
 x[2,4]  x[2,5]
 x[3,4]  x[3,5]

julia> @btime getindex($x, 2, 4)
  3.466 ns (0 allocations: 0 bytes)
x[2,4]

julia> exit()
(base) oscar@Oscars-MBP JuMP % git checkout master
Switched to branch 'master'
Your branch is up to date with 'origin/master'.
(base) oscar@Oscars-MBP JuMP % julia --project=.  
               _
   _       _ _(_)_     |  Documentation: https://docs.julialang.org
  (_)     | (_) (_)    |
   _ _   _| |_  __ _   |  Type "?" for help, "]?" for Pkg help.
  | | | | | | |/ _` |  |
  | | |_| | | | (_| |  |  Version 1.6.7 (2022-07-19)
 _/ |\__'_|_|_|\__'_|  |  Official https://julialang.org/ release
|__/                   |

julia> using JuMP
[ Info: Precompiling JuMP [4076af6c-e467-56ae-b986-b466b2749572]

julia> using BenchmarkTools

julia> model = Model()
A JuMP Model
Feasibility problem with:
Variables: 0
Model mode: AUTOMATIC
CachingOptimizer state: NO_OPTIMIZER
Solver name: No optimizer attached.

julia> @variable(model, x[i=2:3, j=4:5])
2-dimensional DenseAxisArray{VariableRef,2,...} with index sets:
    Dimension 1, 2:3
    Dimension 2, 4:5
And data, a 2×2 Matrix{VariableRef}:
 x[2,4]  x[2,5]
 x[3,4]  x[3,5]

julia> @btime getindex($x, 2, 4)
  3.725 ns (0 allocations: 0 bytes)
x[2,4]

There's still room to improve the kwarg option, but let's settle on the syntax/opt-in stuff first:

julia> @btime getindex($x; i = 2, j = 4)
  33.646 ns (1 allocation: 32 bytes)
x[2,4]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is kwarg indexing fundamentally 10x slower than positional indexing? If that's true we might want to reconsider the syntax.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, if you change the method to a singular getindex(x; kwargs...) it comes down to 12ns and 0 allocations, but then I needed to add a bunch of additional methods to avoid ambiguities. Left as-is for now while we sort the syntax.

Keyword indexing in general does have a slight performance hit though.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keyword indexing in general does have a slight performance hit though.

In fact, keyword indexing can be as fast as Cartesian indexing if the tuple of names is encoded into the customized array type, but it is obviously not an option for JuMP now.

if !isempty(args)
error("Cannot index with mix of positional and keyword arguments")
end
return getindex(x, _kwargs_to_args(x; kwargs...)...)
end
indices = _type_stable_args(x.axes, args)
return getindex(x.data, indices...)
end
Expand All @@ -691,7 +778,18 @@ function Base.setindex!(
return setindex!(a, value, k.I...)
end

function Base.setindex!(x::DenseAxisArrayView{T}, value::T, args...) where {T}
function Base.setindex!(
x::DenseAxisArrayView{T},
value::T,
args...;
kwargs...,
) where {T}
if !isempty(kwargs)
if !isempty(args)
error("Cannot index with mix of positional and keyword arguments")
end
return setindex!(x, value, _kwargs_to_args(x; kwargs...)...)
end
indices = _type_stable_args(x.axes, args)
return setindex!(x.data, value, indices...)
end
Expand Down
Loading