Skip to content

Commit

Permalink
Allow diff rule function to be in an expression or QuoteNode
Browse files Browse the repository at this point in the history
Julia 0.7 changes the macro argument parsing to consistently use QuoteNodes instead of
Expr(:quote...) objects (#23885).

This also fixes some swapped LHS/RHS uses in the docs/variables.
  • Loading branch information
jdlangs committed Oct 29, 2017
1 parent e36b1fd commit 3762b1e
Showing 1 changed file with 24 additions and 11 deletions.
35 changes: 24 additions & 11 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct DiffRule{M,f} end
Define a new differentiation rule for the function `M.f` and the given arguments, which should
be treated as bindings to Julia expressions.
The RHS should be a function call with a non-splatted argument list, and the LHS should be
The LHS should be a function call with a non-splatted argument list, and the RHS should be
the derivative expression, or in the `n`-ary case, an `n`-tuple of expressions where the
`i`th expression is the derivative of `f` w.r.t the `i`th argument. Arguments should be
interpolated wherever they are used on the RHS.
Expand All @@ -27,17 +27,16 @@ Examples:
"""
macro define_diffrule(def)
@assert isa(def, Expr) && def.head == :(=)
rhs = def.args[1]
lhs = def.args[2]
@assert isa(rhs, Expr) && rhs.head == :call
qualified_f = rhs.args[1]
@assert isa(qualified_f, Expr) && qualified_f.head == :(.)
@assert isa(def, Expr) && def.head == :(=) "Diff rule expression does not have a left and right side"
lhs = def.args[1]
rhs = def.args[2]
@assert isa(lhs, Expr) && lhs.head == :call "LHS is not a function call"
qualified_f = lhs.args[1]
@assert isa(qualified_f, Expr) && qualified_f.head == :(.) "Function is not qualified by module"
M, quoted_f = qualified_f.args
@assert isa(quoted_f, Expr) && quoted_f.head == :quote
f = first(quoted_f.args)
args = rhs.args[2:end]
rhs.args[1] = :(::Type{$DiffRules.DiffRule{$(Expr(:quote, M)),$(Expr(:quote, f))}})
f = _get_quoted_symbol(quoted_f)
args = lhs.args[2:end]
lhs.args[1] = :(::Type{$DiffRules.DiffRule{$(Expr(:quote, M)),$(Expr(:quote, f))}})
key = (M, f, length(args))
in(DEFINED_DIFFRULES, key) || push!(DEFINED_DIFFRULES, key)
return esc(def)
Expand Down Expand Up @@ -111,3 +110,17 @@ Examples:
"""
diffrules() = DEFINED_DIFFRULES

#For v0.6 and v0.7 compatibility, need to support having the diff rule function enter as a
#`Expr(:quote...)` and a `QuoteNode`. When v0.6 support is dropped, the function will always enter
#in a `QuoteNode` (#23885).
function _get_quoted_symbol(ex::Expr)
@assert ex.head == :quote
@assert length(ex.args) == 1 && isa(ex.args[1], Symbol) "Function not a single symbol"
ex.args[1]
end

function _get_quoted_symbol(ex::QuoteNode)
@assert isa(ex.value, Symbol) "Function not a single symbol"
ex.value
end

0 comments on commit 3762b1e

Please sign in to comment.