diff --git a/Project.toml b/Project.toml index aa26c0f3c..bad5f567f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.23.0" +version = "1.24.0" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 286f71db2..544b3d9c6 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -2,7 +2,7 @@ module ChainRulesCore using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize! using Base.Meta using LinearAlgebra -using Compat: hasfield, hasproperty, ismutabletype +using Compat: hasfield, hasproperty, ismutabletype, Returns export frule, rrule # core function # rule configurations diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 0453d6368..88ec3e8aa 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -436,9 +436,7 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke) tup_expr = tuple_expression(primal_sig_parts) primal_name = first(primal_invoke.args) pullback_expr = @strip_linenos quote - function $(esc(propagator_name(primal_name, :pullback)))(@nospecialize(_)) - return $(tup_expr) - end + Returns($(tup_expr)) end @gensym kwargs diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 43863a915..de31941ed 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -42,6 +42,20 @@ end @testset "rule_definition_tools.jl" begin @testset "@non_differentiable" begin + @testset "issue #678: identical pullback objects" begin + issue_678_f(::Any) = nothing + issue_678_g(::Any) = nothing + issue_678_h(::Any...) = nothing + @non_differentiable issue_678_f(::Any) + @non_differentiable issue_678_g(::Any) + @non_differentiable issue_678_h(::Any...) + @test ( + last(rrule(issue_678_f, 0.1)) === + last(rrule(issue_678_g, 0.2)) === + last(rrule(issue_678_h, 0.3)) + ) + end + @testset "two input one output function" begin nondiff_2_1(x, y) = fill(7.5, 100)[x + y] @non_differentiable nondiff_2_1(::Any, ::Any)