Skip to content

Commit

Permalink
Merge pull request #4 from jdlangs/quotenode_expr_compat
Browse files Browse the repository at this point in the history
Allow diff rule function to be in an expression or QuoteNode
  • Loading branch information
jrevels authored Nov 2, 2017
2 parents e36b1fd + 3762b1e commit 823e36c
Showing 1 changed file with 24 additions and 11 deletions.
35 changes: 24 additions & 11 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct DiffRule{M,f} end
Define a new differentiation rule for the function `M.f` and the given arguments, which should
be treated as bindings to Julia expressions.
The RHS should be a function call with a non-splatted argument list, and the LHS should be
The LHS should be a function call with a non-splatted argument list, and the RHS should be
the derivative expression, or in the `n`-ary case, an `n`-tuple of expressions where the
`i`th expression is the derivative of `f` w.r.t the `i`th argument. Arguments should be
interpolated wherever they are used on the RHS.
Expand All @@ -27,17 +27,16 @@ Examples:
"""
macro define_diffrule(def)
@assert isa(def, Expr) && def.head == :(=)
rhs = def.args[1]
lhs = def.args[2]
@assert isa(rhs, Expr) && rhs.head == :call
qualified_f = rhs.args[1]
@assert isa(qualified_f, Expr) && qualified_f.head == :(.)
@assert isa(def, Expr) && def.head == :(=) "Diff rule expression does not have a left and right side"
lhs = def.args[1]
rhs = def.args[2]
@assert isa(lhs, Expr) && lhs.head == :call "LHS is not a function call"
qualified_f = lhs.args[1]
@assert isa(qualified_f, Expr) && qualified_f.head == :(.) "Function is not qualified by module"
M, quoted_f = qualified_f.args
@assert isa(quoted_f, Expr) && quoted_f.head == :quote
f = first(quoted_f.args)
args = rhs.args[2:end]
rhs.args[1] = :(::Type{$DiffRules.DiffRule{$(Expr(:quote, M)),$(Expr(:quote, f))}})
f = _get_quoted_symbol(quoted_f)
args = lhs.args[2:end]
lhs.args[1] = :(::Type{$DiffRules.DiffRule{$(Expr(:quote, M)),$(Expr(:quote, f))}})
key = (M, f, length(args))
in(DEFINED_DIFFRULES, key) || push!(DEFINED_DIFFRULES, key)
return esc(def)
Expand Down Expand Up @@ -111,3 +110,17 @@ Examples:
"""
diffrules() = DEFINED_DIFFRULES

#For v0.6 and v0.7 compatibility, need to support having the diff rule function enter as a
#`Expr(:quote...)` and a `QuoteNode`. When v0.6 support is dropped, the function will always enter
#in a `QuoteNode` (#23885).
function _get_quoted_symbol(ex::Expr)
@assert ex.head == :quote
@assert length(ex.args) == 1 && isa(ex.args[1], Symbol) "Function not a single symbol"
ex.args[1]
end

function _get_quoted_symbol(ex::QuoteNode)
@assert isa(ex.value, Symbol) "Function not a single symbol"
ex.value
end

0 comments on commit 823e36c

Please sign in to comment.