Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved and simplified BinaryOperation with "stubborn" location inference #1599

Merged
merged 15 commits into from
Apr 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/AbstractOperations/AbstractOperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ Base.parent(op::AbstractOperation) = op
# AbstractOperation macros add their associated functions to this list
const operators = Set()

include("at.jl")
include("grid_validation.jl")

include("unary_operations.jl")
Expand Down Expand Up @@ -78,4 +77,7 @@ eval(define_multiary_operator(:*))
push!(operators, :*)
push!(multiary_operators, :*)

include("at.jl")

end # module

18 changes: 17 additions & 1 deletion src/AbstractOperations/at.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ end
"Fallback for when `insert_location` is called on objects other than expressions."
insert_location!(anything, location) = nothing

# A very special UnaryOperation
@inbounds identity(i, j, k, grid, a::Number) = a
@inbounds identity(i, j, k, grid, a::AbstractField) = @inbounds a[i, j, k]

function interpolate_operation(L, x::AbstractField)
L == location(x) && return x # Don't interpolate unecessarily
return _unary_operation(L, identity, x, location(x), x.grid)
end

"""
@at location abstract_operation

Expand All @@ -30,5 +39,12 @@ Modify the `abstract_operation` so that it returns values at
"""
macro at(location, abstract_operation)
insert_location!(abstract_operation, location)
return esc(abstract_operation)

# We wrap it all in an interpolator to help "stubborn" binary operations
# arrive in the right place.
wrapped_operation = quote
interpolate_operation($(esc(location)), $(esc(abstract_operation)))
end

return wrapped_operation
end
72 changes: 36 additions & 36 deletions src/AbstractOperations/binary_operations.jl
Original file line number Diff line number Diff line change
@@ -1,51 +1,48 @@
const binary_operators = Set()

"""
BinaryOperation{X, Y, Z, O, A, B, IA, IB, IΩ, G} <: AbstractOperation{X, Y, Z, G}

An abstract representation of a binary operation on `AbstractField`s.
"""
struct BinaryOperation{X, Y, Z, O, A, B, IA, IB, IΩ, G} <: AbstractOperation{X, Y, Z, G}
struct BinaryOperation{X, Y, Z, O, A, B, IA, IB, G} <: AbstractOperation{X, Y, Z, G}
op :: O
a :: A
b :: B
▶a :: IA
▶b :: IB
▶op :: IΩ
grid :: G

"""
BinaryOperation{X, Y, Z}(op, a, b, ▶a, ▶b, ▶op, grid)
BinaryOperation{X, Y, Z}(op, a, b, ▶a, ▶b, grid)

Returns an abstract representation of the binary operation `op(▶a(a), ▶b(b))`,
followed by interpolation by `▶op` to `(X, Y, Z)`, where `▶a` and `▶b` interpolate
`a` and `b` to a common location.
Returns an abstract representation of the binary operation `op(▶a(a), ▶b(b))`.
where `▶a` and `▶b` interpolate `a` and `b` to (X, Y, Z).
"""
function BinaryOperation{X, Y, Z}(op, a, b, ▶a, ▶b, ▶op, grid) where {X, Y, Z}

any((X, Y, Z) .=== Nothing) && throw(ArgumentError("Nothing locations are invalid! " *
"Cannot construct BinaryOperation at ($X, $Y, $Z)."))

function BinaryOperation{X, Y, Z}(op, a, b, ▶a, ▶b, grid) where {X, Y, Z}
return new{X, Y, Z, typeof(op), typeof(a), typeof(b), typeof(▶a), typeof(▶b),
typeof(▶op), typeof(grid)}(op, a, b, ▶a, ▶b, ▶op, grid)
typeof(grid)}(op, a, b, ▶a, ▶b, grid)
end
end

@inline Base.getindex(β::BinaryOperation, i, j, k) = β.op(i, j, k, β.grid, β.op, β.▶a, β.▶b, β.a, β.b)
@inline Base.getindex(β::BinaryOperation, i, j, k) = β.op(i, j, k, β.grid, β.▶a, β.▶b, β.a, β.b)

#####
##### BinaryOperation construction
#####

"""Create a binary operation for `op` acting on `a` and `b` with locations `La` and `Lb`.
The operator acts at `Lab` and the result is interpolated to `Lc`."""
function _binary_operation(Lc, op, a, b, La, Lb, Lab, grid)
▶a = interpolation_operator(La, Lab)
▶b = interpolation_operator(Lb, Lab)
▶op = interpolation_operator(Lab, Lc)
return BinaryOperation{Lc[1], Lc[2], Lc[3]}(op, a, b, ▶a, ▶b, ▶op, grid)
function _binary_operation(Lc, op, a, b, La, Lb, grid)
▶a = interpolation_operator(La, Lc)
▶b = interpolation_operator(Lb, Lc)
return BinaryOperation{Lc[1], Lc[2], Lc[3]}(op, a, b, ▶a, ▶b, grid)
end

const ConcreteLocationType = Union{Type{Face}, Type{Center}}

# Precedence rules for choosing operation location:
choose_location(La, Lb, Lc) = Lc # Fallback to the specification Lc, but also...
choose_location(::Type{Face}, ::Type{Face}, Lc) = Face # keep common locations; and
choose_location(::Type{Center}, ::Type{Center}, Lc) = Center #
choose_location(La::ConcreteLocationType, ::Type{Nothing}, Lc) = La # don't interpolate unspecified locations.
choose_location(::Type{Nothing}, Lb::ConcreteLocationType, Lc) = Lb #

"""Return an expression that defines an abstract `BinaryOperator` named `op` for `AbstractField`."""
function define_binary_operator(op)
return quote
Expand All @@ -60,26 +57,29 @@ function define_binary_operator(op)
@inbounds $op(▶a(i, j, k, grid, a), ▶b(i, j, k, grid, b))

"""
$($op)(Lc, Lab, a, b)
$($op)(Lc, a, b)

Returns an abstract representation of the operator `$($op)` acting on `a` and `b` at
location `Lab`, and subsequently interpolated to location `Lc`.
Returns an abstract representation of the operator `$($op)` acting on `a` and `b`.
The operation occurs at location(a) except for Nothing dimensions. In that case,
the location of the dimension in question is supplied either by location(b) or
if that is also Nothing, Lc.
"""
function $op(Lc::Tuple, Lop::Tuple, a, b)
function $op(Lc::Tuple, a, b)
La = location(a)
Lb = location(b)
Lab = choose_location.(La, Lb, Lc)

grid = Oceananigans.AbstractOperations.validate_grid(a, b)
return Oceananigans.AbstractOperations._binary_operation(Lc, $op, a, b, La, Lb, Lop, grid)

return Oceananigans.AbstractOperations._binary_operation(Lab, $op, a, b, La, Lb, grid)
end

$op(Lc::Tuple, a, b) = $op(Lc, Lc, a, b)
$op(Lc::Tuple, a::Number, b) = $op(Lc, location(b), a, b)
$op(Lc::Tuple, a, b::Number) = $op(Lc, location(a), a, b)
$op(Lc::Tuple, a::AF{X, Y, Z}, b::AF{X, Y, Z}) where {X, Y, Z} = $op(Lc, location(a), a, b)
# Numbers are not fields...
$op(Lc::Tuple, a::Number, b::Number) = $op(a, b)

# Sugar for mixing in functions of (x, y, z)
$op(Lc::Tuple, a::Function, b::AbstractField) = $op(Lc, FunctionField(Lc, a, b.grid), b)
$op(Lc::Tuple, a::AbstractField, b::Function) = $op(Lc, a, FunctionField(Lc, b, a.grid))
$op(Lc::Tuple, f::Function, b::AbstractField) = $op(Lc, FunctionField(location(b), f, b.grid), b)
$op(Lc::Tuple, a::AbstractField, f::Function) = $op(Lc, a, FunctionField(location(a), f, a.grid))

# Sugary versions with default locations
$op(a::AF, b::AF) = $op(location(a), a, b)
Expand Down Expand Up @@ -184,5 +184,5 @@ end
"Adapt `BinaryOperation` to work on the GPU via CUDAnative and CUDAdrv."
Adapt.adapt_structure(to, binary::BinaryOperation{X, Y, Z}) where {X, Y, Z} =
BinaryOperation{X, Y, Z}(Adapt.adapt(to, binary.op), Adapt.adapt(to, binary.a), Adapt.adapt(to, binary.b),
Adapt.adapt(to, binary.▶a), Adapt.adapt(to, binary.▶b), Adapt.adapt(to, binary.▶op),
binary.grid)
Adapt.adapt(to, binary.▶a), Adapt.adapt(to, binary.▶b), Adapt.adapt(to, binary.grid))

7 changes: 1 addition & 6 deletions src/AbstractOperations/derivatives.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
using Oceananigans.Operators: interpolation_code

"""
Derivative{X, Y, Z, D, A, I, G} <: AbstractOperation{X, Y, Z, G}

An abstract representation of a derivative of an `AbstractField`.
"""
struct Derivative{X, Y, Z, D, A, I, G} <: AbstractOperation{X, Y, Z, G}
∂ :: D
arg :: A
Expand Down Expand Up @@ -122,4 +117,4 @@ compute_at!(∂::Derivative, time) = compute_at!(∂.arg, time)
"Adapt `Derivative` to work on the GPU via CUDAnative and CUDAdrv."
Adapt.adapt_structure(to, deriv::Derivative{X, Y, Z}) where {X, Y, Z} =
Derivative{X, Y, Z}(Adapt.adapt(to, deriv.∂), Adapt.adapt(to, deriv.arg),
Adapt.adapt(to, deriv.▶), deriv.grid)
Adapt.adapt(to, deriv.▶), Adapt.adapt(to, deriv.grid))
2 changes: 1 addition & 1 deletion src/AbstractOperations/multiary_operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,4 +162,4 @@ end
"Adapt `MultiaryOperation` to work on the GPU via CUDAnative and CUDAdrv."
Adapt.adapt_structure(to, multiary::MultiaryOperation{X, Y, Z}) where {X, Y, Z} =
MultiaryOperation{X, Y, Z}(Adapt.adapt(to, multiary.op), Adapt.adapt(to, multiary.args),
Adapt.adapt(to, multiary.▶), multiary.grid)
Adapt.adapt(to, multiary.▶), Adapt.adapt(to, multiary.grid))
2 changes: 1 addition & 1 deletion src/AbstractOperations/show_abstract_operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ end
function tree_show(binary::BinaryOperation{X, Y, Z}, depth, nesting) where {X, Y, Z}
padding = get_tree_padding(depth, nesting)

return string(binary.op, " at ", show_location(X, Y, Z), " via ", show_interp(binary.▶op), '\n',
return string(binary.op, " at ", show_location(X, Y, Z), '\n',
padding, "├── ", tree_show(binary.a, depth+1, nesting+1), '\n',
padding, "└── ", tree_show(binary.b, depth+1, nesting))
end
Expand Down
8 changes: 1 addition & 7 deletions src/AbstractOperations/unary_operations.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
const unary_operators = Set()

"""
UnaryOperation{X, Y, Z, O, A, I, G} <: AbstractOperation{X, Y, Z, G}

An abstract representation of a unary operation on an `AbstractField`; or a function
`f(x)` with on argument acting on `x::AbstractField`.
"""
struct UnaryOperation{X, Y, Z, O, A, I, G} <: AbstractOperation{X, Y, Z, G}
op :: O
arg :: A
Expand Down Expand Up @@ -131,4 +125,4 @@ compute_at!(υ::UnaryOperation, time) = compute_at!(υ.arg, time)
"Adapt `UnaryOperation` to work on the GPU via CUDAnative and CUDAdrv."
Adapt.adapt_structure(to, unary::UnaryOperation{X, Y, Z}) where {X, Y, Z} =
UnaryOperation{X, Y, Z}(Adapt.adapt(to, unary.op), Adapt.adapt(to, unary.arg),
Adapt.adapt(to, unary.▶), unary.grid)
Adapt.adapt(to, unary.▶), Adapt.adapt(to, unary.grid))
2 changes: 1 addition & 1 deletion src/Fields/abstract_field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ end
##### AbstractField functionality
#####

@inline location(a) = nothing
@inline location(a) = (Nothing, Nothing, Nothing)

"Returns the location `(X, Y, Z)` of an `AbstractField{X, Y, Z}`."
@inline location(::AbstractField{X, Y, Z}) where {X, Y, Z} = (X, Y, Z) # note no instantiation
Expand Down
66 changes: 43 additions & 23 deletions test/test_abstract_operations_computed_field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -544,8 +544,6 @@ end

u, v, w, T, S = fields(model)

@test_throws ArgumentError @at (Nothing, Nothing, Center) T * S

for ϕ in (u, v, w, T, S)
for op in (sin, cos, sqrt, exp, tanh)
@test op(ϕ) isa UnaryOperation
Expand Down Expand Up @@ -597,6 +595,15 @@ end
@test compute_plus(model)
@test compute_minus(model)
@test compute_times(model)

# Basic compilation test for nested BinaryOperations...
u, v, w = model.velocities
@test try
compute!(ComputedField(u + v - w))
true
catch
false
end
end

@testset "Multiary computations [$FT, $(typeof(arch))]" begin
Expand Down Expand Up @@ -661,12 +668,7 @@ end
@info " Testing operations with AveragedField..."

T, S = model.tracers

TS = AveragedField(T * S, dims=(1, 2))

@test_throws ArgumentError @at (Nothing, Nothing, Center) T * S
@test_throws ArgumentError TS * S

@test operations_with_averaged_field(model)
end

Expand Down Expand Up @@ -695,26 +697,44 @@ end

@test computations_with_averaged_field_derivative(model)

# These don't work on the GPU right now
if arch isa CPU
@test computations_with_averaged_fields(model)
else
@test_skip computations_with_averaged_fields(model)
end
u, v, w = model.velocities

set!(model, enforce_incompressibility = false, u = (x, y, z) -> z, v = 2, w = 3)

# Two ways to compute turbulent kinetic energy
U = AveragedField(u, dims=(1, 2))
V = AveragedField(v, dims=(1, 2))

# Build up compilation tests incrementally...
u_prime = u - U
u_prime_ccc = @at (Center, Center, Center) u - U
u_prime_squared = (u - U)^2
u_prime_squared_ccc = @at (Center, Center, Center) (u - U)^2
horizontal_twice_tke = (u - U)^2 + (v - V)^2
horizontal_tke = ((u - U)^2 + (v - V)^2) / 2
horizontal_tke_ccc = @at (Center, Center, Center) ((u - U)^2 + (v - V)^2) / 2
twice_tke = (u - U)^2 + (v - V)^2 + w^2
tke = ((u - U)^2 + (v - V)^2 + w^2) / 2
tke_ccc = @at (Center, Center, Center) ((u - U)^2 + (v - V)^2 + w^2) / 2

@test try compute!(ComputedField(u_prime )); true; catch; false; end
@test try compute!(ComputedField(u_prime_ccc )); true; catch; false; end
@test try compute!(ComputedField(u_prime_squared )); true; catch; false; end
@test try compute!(ComputedField(u_prime_squared_ccc )); true; catch; false; end
@test try compute!(ComputedField(horizontal_twice_tke)); true; catch; false; end
@test try compute!(ComputedField(horizontal_tke )); true; catch; false; end
@test try compute!(ComputedField(twice_tke )); true; catch; false; end

@test try compute!(ComputedField(horizontal_tke_ccc )); true; catch; false; end
@test try compute!(ComputedField(tke )); true; catch; false; end
@test try compute!(ComputedField(tke_ccc )); true; catch; false; end

computed_tke = ComputedField(tke_ccc)
@test (compute!(computed_tke); all(interior(computed_tke)[2:3, 2:3, 2:3] .== 9/2))
end

@testset "Computations with ComputedFields [$FT, $(typeof(arch))]" begin
@info " Testing computations with ComputedField [$FT, $(typeof(arch))]..."

# Basic compilation test...
u, v, w = model.velocities
@test try
compute!(ComputedField(u + v - w))
true
catch
false
end

@test computations_with_computed_fields(model)
end

Expand Down