From aafb6644b6b94b9efa2c0a6761985cd23feb47ee Mon Sep 17 00:00:00 2001 From: Neven Sajko Date: Thu, 30 May 2024 15:44:53 +0200 Subject: [PATCH 1/5] make `rrule` return identical pullback for `zero` as for `one` Could be a minor compilation latency and/or type stability win for some uses. --- src/rulesets/Base/base.jl | 11 +++++++---- test/rulesets/Base/base.jl | 1 + 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 6c66d19ee..a7ebdcbb1 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -4,6 +4,11 @@ @scalar_rule copysign(y, x) (ifelse(signbit(x)!=signbit(y), -one(y), +one(y)), NoTangent()) @scalar_rule transpose(x) true +# TODO: define using `Returns((NoTangent(), ZeroTangent()))` when support for Julia v1.6 is dropped +function _pullback_for_constant(::Any) + (NoTangent(), ZeroTangent()) +end + # `zero` function frule((_, _), ::typeof(zero), x) @@ -11,8 +16,7 @@ function frule((_, _), ::typeof(zero), x) end function rrule(::typeof(zero), x) - zero_pullback(_) = (NoTangent(), ZeroTangent()) - return (zero(x), zero_pullback) + return (zero(x), _pullback_for_constant) end # `one` @@ -22,8 +26,7 @@ function frule((_, _), ::typeof(one), x) end function rrule(::typeof(one), x) - one_pullback(_) = (NoTangent(), ZeroTangent()) - return (one(x), one_pullback) + return (one(x), _pullback_for_constant) end diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 25c755f55..0b9e0a721 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -4,6 +4,7 @@ end @testset "base.jl" begin @testset "zero/one" begin + @test last(rrule(zero, 0.1)) === last(rrule(one, 0.2f0)) for f in [zero, one] for x in [1.0, 1.0im, [10.0+im 11.0-im; 12.0+2im 13.0-3im]] test_frule(f, x) From 21ec79072b0e4552483fda835232d0c6149bf237 Mon Sep 17 00:00:00 2001 From: Neven Sajko Date: Thu, 30 May 2024 16:03:04 +0200 Subject: [PATCH 2/5] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 7732e7d1d..c4c871232 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.66.0" +version = "1.66.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From ab61772880e0f90c480219d9b0f56da25cc075c9 Mon Sep 17 00:00:00 2001 From: Neven Sajko Date: Thu, 30 May 2024 16:06:39 +0200 Subject: [PATCH 3/5] stylistic: add `return` keyword --- src/rulesets/Base/base.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index a7ebdcbb1..97d37a91b 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -6,7 +6,7 @@ # TODO: define using `Returns((NoTangent(), ZeroTangent()))` when support for Julia v1.6 is dropped function _pullback_for_constant(::Any) - (NoTangent(), ZeroTangent()) + return (NoTangent(), ZeroTangent()) end # `zero` From 3169a33cbfd65921ea4f68b06b22abfbace4309b Mon Sep 17 00:00:00 2001 From: Neven Sajko Date: Thu, 30 May 2024 18:37:32 +0200 Subject: [PATCH 4/5] we can use `Returns` after all thanks to Compat.jl --- src/rulesets/Base/base.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 97d37a91b..6673b200a 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -4,10 +4,7 @@ @scalar_rule copysign(y, x) (ifelse(signbit(x)!=signbit(y), -one(y), +one(y)), NoTangent()) @scalar_rule transpose(x) true -# TODO: define using `Returns((NoTangent(), ZeroTangent()))` when support for Julia v1.6 is dropped -function _pullback_for_constant(::Any) - return (NoTangent(), ZeroTangent()) -end +const _pullback_for_constant = Returns((NoTangent(), ZeroTangent())) # `zero` From 363aa6c1dc84a8de6d3f14364f5a1f456b8ba755 Mon Sep 17 00:00:00 2001 From: Neven Sajko Date: Fri, 31 May 2024 13:22:11 +0200 Subject: [PATCH 5/5] simpler --- Project.toml | 2 +- src/rulesets/Base/base.jl | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index c4c871232..5d88267fe 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.66.1" +version = "1.67.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 6673b200a..8c616345a 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -4,8 +4,6 @@ @scalar_rule copysign(y, x) (ifelse(signbit(x)!=signbit(y), -one(y), +one(y)), NoTangent()) @scalar_rule transpose(x) true -const _pullback_for_constant = Returns((NoTangent(), ZeroTangent())) - # `zero` function frule((_, _), ::typeof(zero), x) @@ -13,7 +11,8 @@ function frule((_, _), ::typeof(zero), x) end function rrule(::typeof(zero), x) - return (zero(x), _pullback_for_constant) + zero_pullback = Returns((NoTangent(), ZeroTangent())) + return (zero(x), zero_pullback) end # `one` @@ -23,7 +22,8 @@ function frule((_, _), ::typeof(one), x) end function rrule(::typeof(one), x) - return (one(x), _pullback_for_constant) + one_pullback = Returns((NoTangent(), ZeroTangent())) + return (one(x), one_pullback) end