Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make @non_differentiable use identical pullbacks when possible #679

Merged
merged 3 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,14 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke)
end
end

struct NonDiffPullback{T<:Tuple{Vararg{NoTangent}}} <: Function
v::T
end

function (@nospecialize pb::NonDiffPullback)(@nospecialize ::Any)
return pb.v
end

function tuple_expression(primal_sig_parts)
has_vararg = _isvararg(primal_sig_parts[end])
return if !has_vararg
Expand All @@ -436,9 +444,7 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke)
tup_expr = tuple_expression(primal_sig_parts)
primal_name = first(primal_invoke.args)
pullback_expr = @strip_linenos quote
function $(esc(propagator_name(primal_name, :pullback)))(@nospecialize(_))
return $(tup_expr)
end
NonDiffPullback($(tup_expr))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess even simpler would be

Suggested change
NonDiffPullback($(tup_expr))
Returns($(tup_expr))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That will change the method signature, though. Currently the pullback accepts a single argument, while Returns accepts any amount of arguments. Is that OK to change?

end

@gensym kwargs
Expand Down
26 changes: 26 additions & 0 deletions test/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,32 @@ end

@testset "rule_definition_tools.jl" begin
@testset "@non_differentiable" begin
@testset "`NonDiffPullback`" begin
NDP = ChainRulesCore.NonDiffPullback
for i in 0:5
tup = ntuple((_ -> NoTangent()), i)
ndp = NDP(tup)
@test ndp === @inferred NDP(tup)
@test tup === @inferred ndp(:arbitrary)
@test_throws MethodError ndp()
@test_throws MethodError ndp(1, 2)
end
end

@testset "issue #678: identical pullback objects" begin
issue_678_f(::Any) = nothing
issue_678_g(::Any) = nothing
issue_678_h(::Any...) = nothing
@non_differentiable issue_678_f(::Any)
@non_differentiable issue_678_g(::Any)
@non_differentiable issue_678_h(::Any...)
@test (
last(rrule(issue_678_f, 0.1)) ===
last(rrule(issue_678_g, 0.2)) ===
last(rrule(issue_678_h, 0.3))
)
end

@testset "two input one output function" begin
nondiff_2_1(x, y) = fill(7.5, 100)[x + y]
@non_differentiable nondiff_2_1(::Any, ::Any)
Expand Down
Loading