Skip to content

Commit

Permalink
Dot-broadcasting for short-circuiting ops .&& and .|| (JuliaLang#39594)
Browse files Browse the repository at this point in the history
I have long wanted a proper fix for issue JuliaLang#5187. It was the very first Julia issue I filed.
This is a shot at such a fix. This PR:

* Enables parsing for `.&&` and `.||`.  They are parsed into `Expr(:call, :.&&, ...)` expressions at the same precedence as their respective `&&` and `||`:
    ```julia-repl
    julia> Meta.show_sexpr(:(a .&& b))
    (:call, :.&&, :a, :b)
    ```

* Unlike all other dotted operators `.op` (like `.+`), the `op`-alone part (`var"&&"`) is not an exported name from Base. As such, this effectively lowers to `broadcasted((x,y)->x && y, ...)`, but instead of using an anonymous function I've named it `Base.andand` and `Base.oror`:
    ```julia-repl
    julia> Meta.@lower a .&& b
    :($(Expr(:thunk, CodeInfo(
        @ none within `top-level scope'
    1 ─ %1 = Base.broadcasted(Base.andand, a, b)
    │   %2 = Base.materialize(%1)
    └──      return %2
    ))))
    ```

* I've used a named function to enable short-circuiting behavior _within the broadcast kernel itself_. In the case that the second argument is a part of the same fused broadcast kernel, it will only evaluate if required:
    ```julia-repl
    julia> mutable struct F5187; x; end

    julia> (f::F5187)(x) = (f.x += x)

    julia> (iseven.(1:4) .|| (F5187(0)).(ones(4)))
    4-element Vector{Real}:
        1.0
     true
        2.0
     true
    ```

Co-authored-by: Simeon Schaub <simeondavidschaub99@gmail.com>
  • Loading branch information
2 people authored and ElOceanografo committed May 4, 2021
1 parent 8c6d188 commit 4469d7e
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 14 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ New language features
* `(; a, b) = x` can now be used to destructure properties `a` and `b` of `x`. This syntax is equivalent to `a = getproperty(x, :a)`
and similarly for `b`. ([#39285])
* Implicit multiplication by juxtaposition is now allowed for radical symbols (e.g., `x√y` and `x∛y`). ([#40173])
* The short-circuiting operators `&&` and `||` can now be dotted to participate in broadcast fusion
as `.&&` and `.||`. ([#39594])

Language changes
----------------
Expand Down
29 changes: 19 additions & 10 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using .Base.Cartesian
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, @pure,
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias
import .Base: copy, copyto!, axes
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__, broadcast_preserving_zero_d, BroadcastFunction
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__, broadcast_preserving_zero_d, BroadcastFunction, andand, oror

## Computing the result's axes: deprecated name
const broadcast_axes = axes
Expand Down Expand Up @@ -179,6 +179,21 @@ function Broadcasted{Style}(f::F, args::Args, axes=nothing) where {Style, F, Arg
Broadcasted{Style, typeof(axes), Core.Typeof(f), Args}(f, args, axes)
end

struct AndAnd end
andand = AndAnd()
broadcasted(::AndAnd, a, b) = broadcasted((a, b) -> a && b, a, b)
function broadcasted(::AndAnd, a, bc::Broadcasted)
bcf = flatten(bc)
broadcasted((a, args...) -> a && bcf.f(args...), a, bcf.args...)
end
struct OrOr end
const oror = OrOr()
broadcasted(::OrOr, a, b) = broadcasted((a, b) -> a || b, a, b)
function broadcasted(::OrOr, a, bc::Broadcasted)
bcf = flatten(bc)
broadcasted((a, args...) -> a || bcf.f(args...), a, bcf.args...)
end

Base.convert(::Type{Broadcasted{NewStyle}}, bc::Broadcasted{Style,Axes,F,Args}) where {NewStyle,Style,Axes,F,Args} =
Broadcasted{NewStyle,Axes,F,Args}(bc.f, bc.args, bc.axes)

Expand Down Expand Up @@ -1257,15 +1272,9 @@ function __dot__(x::Expr)
tmp = x.head === :(<:) ? :.<: : :.>:
Expr(:call, tmp, dotargs...)
else
if x.head === :&& || x.head === :||
error("""
Using `&&` and `||` is disallowed in `@.` expressions.
Use `&` or `|` for elementwise logical operations.
""")
end
head = string(x.head)
if last(head) == '=' && first(head) != '.'
Expr(Symbol('.',head), dotargs...)
head = String(x.head)::String
if last(head) == '=' && first(head) != '.' || head == "&&" || head == "||"
Expr(Symbol('.', head), dotargs...)
else
Expr(x.head, dotargs...)
end
Expand Down
8 changes: 4 additions & 4 deletions src/julia-parser.scm
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
(define prec-pair (add-dots '(=>)))
(define prec-conditional '(?))
(define prec-arrow (add-dots '(← → ↔ ↚ ↛ ↞ ↠ ↢ ↣ ↦ ↤ ↮ ⇎ ⇍ ⇏ ⇐ ⇒ ⇔ ⇴ ⇶ ⇷ ⇸ ⇹ ⇺ ⇻ ⇼ ⇽ ⇾ ⇿ ⟵ ⟶ ⟷ ⟹ ⟺ ⟻ ⟼ ⟽ ⟾ ⟿ ⤀ ⤁ ⤂ ⤃ ⤄ ⤅ ⤆ ⤇ ⤌ ⤍ ⤎ ⤏ ⤐ ⤑ ⤔ ⤕ ⤖ ⤗ ⤘ ⤝ ⤞ ⤟ ⤠ ⥄ ⥅ ⥆ ⥇ ⥈ ⥊ ⥋ ⥎ ⥐ ⥒ ⥓ ⥖ ⥗ ⥚ ⥛ ⥞ ⥟ ⥢ ⥤ ⥦ ⥧ ⥨ ⥩ ⥪ ⥫ ⥬ ⥭ ⥰ ⧴ ⬱ ⬰ ⬲ ⬳ ⬴ ⬵ ⬶ ⬷ ⬸ ⬹ ⬺ ⬻ ⬼ ⬽ ⬾ ⬿ ⭀ ⭁ ⭂ ⭃ ⭄ ⭇ ⭈ ⭉ ⭊ ⭋ ⭌ ← → ⇜ ⇝ ↜ ↝ ↩ ↪ ↫ ↬ ↼ ↽ ⇀ ⇁ ⇄ ⇆ ⇇ ⇉ ⇋ ⇌ ⇚ ⇛ ⇠ ⇢ ↷ ↶ ↺ ↻ --> <-- <-->)))
(define prec-lazy-or '(|\|\||))
(define prec-lazy-and '(&&))
(define prec-lazy-or (add-dots '(|\|\||)))
(define prec-lazy-and (add-dots '(&&)))
(define prec-comparison
(append! '(in isa)
(add-dots '(> < >= ≥ <= ≤ == === ≡ != ≠ !== ≢ ∈ ∉ ∋ ∌ ⊆ ⊈ ⊂ ⊄ ⊊ ∝ ∊ ∍ ∥ ∦ ∷ ∺ ∻ ∽ ∾ ≁ ≃ ≂ ≄ ≅ ≆ ≇ ≈ ≉ ≊ ≋ ≌ ≍ ≎ ≐ ≑ ≒ ≓ ≖ ≗ ≘ ≙ ≚ ≛ ≜ ≝ ≞ ≟ ≣ ≦ ≧ ≨ ≩ ≪ ≫ ≬ ≭ ≮ ≯ ≰ ≱ ≲ ≳ ≴ ≵ ≶ ≷ ≸ ≹ ≺ ≻ ≼ ≽ ≾ ≿ ⊀ ⊁ ⊃ ⊅ ⊇ ⊉ ⊋ ⊏ ⊐ ⊑ ⊒ ⊜ ⊩ ⊬ ⊮ ⊰ ⊱ ⊲ ⊳ ⊴ ⊵ ⊶ ⊷ ⋍ ⋐ ⋑ ⋕ ⋖ ⋗ ⋘ ⋙ ⋚ ⋛ ⋜ ⋝ ⋞ ⋟ ⋠ ⋡ ⋢ ⋣ ⋤ ⋥ ⋦ ⋧ ⋨ ⋩ ⋪ ⋫ ⋬ ⋭ ⋲ ⋳ ⋴ ⋵ ⋶ ⋷ ⋸ ⋹ ⋺ ⋻ ⋼ ⋽ ⋾ ⋿ ⟈ ⟉ ⟒ ⦷ ⧀ ⧁ ⧡ ⧣ ⧤ ⧥ ⩦ ⩧ ⩪ ⩫ ⩬ ⩭ ⩮ ⩯ ⩰ ⩱ ⩲ ⩳ ⩵ ⩶ ⩷ ⩸ ⩹ ⩺ ⩻ ⩼ ⩽ ⩾ ⩿ ⪀ ⪁ ⪂ ⪃ ⪄ ⪅ ⪆ ⪇ ⪈ ⪉ ⪊ ⪋ ⪌ ⪍ ⪎ ⪏ ⪐ ⪑ ⪒ ⪓ ⪔ ⪕ ⪖ ⪗ ⪘ ⪙ ⪚ ⪛ ⪜ ⪝ ⪞ ⪟ ⪠ ⪡ ⪢ ⪣ ⪤ ⪥ ⪦ ⪧ ⪨ ⪩ ⪪ ⪫ ⪬ ⪭ ⪮ ⪯ ⪰ ⪱ ⪲ ⪳ ⪴ ⪵ ⪶ ⪷ ⪸ ⪹ ⪺ ⪻ ⪼ ⪽ ⪾ ⪿ ⫀ ⫁ ⫂ ⫃ ⫄ ⫅ ⫆ ⫇ ⫈ ⫉ ⫊ ⫋ ⫌ ⫍ ⫎ ⫏ ⫐ ⫑ ⫒ ⫓ ⫔ ⫕ ⫖ ⫗ ⫘ ⫙ ⫷ ⫸ ⫹ ⫺ ⊢ ⊣ ⟂ <: >:))))
Expand Down Expand Up @@ -111,8 +111,8 @@

; operators that are special forms, not function names
(define syntactic-operators
(append! (add-dots '(= += -= *= /= //= |\\=| ^= ÷= %= <<= >>= >>>= |\|=| &= ⊻=))
'(:= $= && |\|\|| |.| ... ->)))
(append! (add-dots '(&& |\|\|| = += -= *= /= //= |\\=| ^= ÷= %= <<= >>= >>>= |\|=| &= ⊻=))
'(:= $= |.| ... ->)))
(define syntactic-unary-operators '($ & |::|))

(define syntactic-op? (Set syntactic-operators))
Expand Down
9 changes: 9 additions & 0 deletions src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -1804,6 +1804,10 @@
e))))
((and (pair? e) (eq? (car e) 'comparison))
(dot-to-fuse (expand-compare-chain (cdr e)) top))
((and (pair? e) (eq? (car e) '.&&))
(make-fuse '(top andand) (cdr e)))
((and (pair? e) (eq? (car e) '|.\|\||))
(make-fuse '(top oror) (cdr e)))
(else e)))
(let ((e (dot-to-fuse rhs #t)) ; an expression '(fuse func args) if expr is a dot call
(lhs-view (ref-to-view lhs))) ; x[...] expressions on lhs turn in to view(x, ...) to update x in-place
Expand Down Expand Up @@ -2125,6 +2129,11 @@
;; e = (|.| f x)
(expand-fuse-broadcast '() e)))

'.&&
(lambda (e) (expand-fuse-broadcast '() e))
'|.\|\||
(lambda (e) (expand-fuse-broadcast '() e))

'.=
(lambda (e)
(expand-fuse-broadcast (cadr e) (caddr e)))
Expand Down
29 changes: 29 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,35 @@ p0 = copy(p)
@test repr(.!) == "Base.Broadcast.BroadcastFunction(!)"
@test eval(:(.+)) == Base.BroadcastFunction(+)

@testset "Issue #5187: Broadcasting of short-circuiting ops" begin
ex = Meta.parse("A .< 1 .|| A .> 2")
@test ex == :((A .< 1) .|| (A .> 2))
@test ex.head == :.||
ex = Meta.parse("A .< 1 .&& A .> 2")
@test ex == :((A .< 1) .&& (A .> 2))
@test ex.head == :.&&

A = -1:4
@test (A .< 1 .|| A .> 2) == [true, true, false, false, true, true]
@test (A .>= 1 .&& A .<= 2) == [false, false, true, true, false, false]

mutable struct F5187; x; end
(f::F5187)(x) = (f.x += x)
@test (iseven.(1:4) .&& (F5187(0)).(ones(4))) == [false, 1, false, 2]
@test (iseven.(1:4) .|| (F5187(0)).(ones(4))) == [1, true, 2, true]
r = 1:4; o = ones(4); f = F5187(0);
@test (@. iseven(r) && f(o)) == [false, 1, false, 2]
@test (@. iseven(r) || f(o)) == [3, true, 4, true]

@test (iseven.(1:8) .&& iseven.((F5187(0)).(ones(8))) .&& (F5187(0)).(ones(8))) == [false,false,false,1,false,false,false,2]
@test (iseven.(1:8) .|| iseven.((F5187(0)).(ones(8))) .|| (F5187(0)).(ones(8))) == [1,true,true,true,2,true,true,true]
r = 1:8; o = ones(8); f1 = F5187(0); f2 = F5187(0)
@test (@. iseven(r) && iseven(f1(o)) && f2(o)) == [false,false,false,1,false,false,false,2]
@test (@. iseven(r) || iseven(f1(o)) || f2(o)) == [3,true,true,true,4,true,true,true]
@test (iseven.(1:8) .&& iseven.((F5187(0)).(ones(8))) .&& (F5187(0)).(ones(8))) == [false,false,false,1,false,false,false,2]
@test (iseven.(1:8) .|| iseven.((F5187(0)).(ones(8))) .|| (F5187(0)).(ones(8))) == [1,true,true,true,2,true,true,true]
end

@testset "Issue #28382: inferrability of broadcast with Union eltype" begin
@test isequal([1, 2] .+ [3.0, missing], [4.0, missing])
@test Core.Compiler.return_type(broadcast, Tuple{typeof(+), Vector{Int},
Expand Down

0 comments on commit 4469d7e

Please sign in to comment.