
Reverse mode apply iterate #1485

Merged: 25 commits, Jun 8, 2024

Conversation

@wsmoses (Member) commented May 30, 2024

No description provided.

@wsmoses (Member, Author) commented Jun 4, 2024

wmoses@beast:~/git/Enzyme.jl ((HEAD detached at origin/reviterate)) $ cat batch.jl

```julia
using Enzyme

Enzyme.API.printall!(true)

concat() = ()
concat(a) = a
concat(a, b) = (a..., b...)
concat(a, b, c...) = concat(concat(a, b), c...)

function make_byref(out, x)
    res = 0.0
    x = Base.inferencebarrier(@inbounds (x[1][1], x[1][2]))
    for v in x
        v = v::Float64
        res += v * v
    end
    out[] = res
    nothing
end

function make_byref2(out, x)
    res = 0.0
    x = Base.inferencebarrier(@inbounds (x[1][1], x[1][2]))
    tup = iterate(x)
    if tup !== nothing
        res += tup[1]::Float64
    end
    out[] = res
    nothing
end

x = [(2.0, 3.0), (7.9, 11.2)]
dx = [(0.0, 0.0), (0.0, 0.0)]
dx2 = [(0.0, 0.0), (0.0, 0.0)]
out = Ref(0.0)
dout = Ref(1.0)
dout2 = Ref(3.0)
res = Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), BatchDuplicated(x, (dx, dx2)))
```
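For context on why apply-generic `iterate` calls show up here at all: Julia lowers `for v in x` into explicit `iterate` calls, and behind the inference barrier those become dynamic dispatches. A rough hand-written sketch of that lowering (the helper name `make_byref_lowered` is ours, not from the PR):

```julia
# Sketch: roughly what `for v in x ... end` in make_byref lowers to.
# The compiler emits iterate(x) / iterate(x, state); when the element type
# cannot be inferred, each call becomes a dynamic (apply-generic) dispatch.
function make_byref_lowered(out, x)
    res = 0.0
    x = Base.inferencebarrier(@inbounds (x[1][1], x[1][2]))
    next = iterate(x)
    while next !== nothing
        v, state = next
        v = v::Float64
        res += v * v
        next = iterate(x, state)
    end
    out[] = res
    nothing
end
```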

@gbaraldi (Contributor)

github-actions bot commented Jun 4, 2024

Benchmark Results

| Benchmark       | main            | 18c6fa1...       | main/18c6fa18c064e6... |
|-----------------|-----------------|------------------|------------------------|
| basics/overhead | 4.03 ± 0.001 ns | 4.34 ± 0.01 ns   | 0.929                  |
| time_to_load    | 0.35 ± 0.0015 s | 0.376 ± 0.0013 s | 0.932                  |

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact to the workflow run for this PR.
Go to "Actions" -> "Benchmark a pull request" -> [the most recent run] -> "Artifacts" (at the bottom).

@wsmoses (Member, Author) commented Jun 5, 2024

@vchuravy per the Signal conversation, here is my best recollection of where this left off (@gbaraldi please correct me).

We found a segfault. Specifically, we are apparently setting the name of a datatype to something new (specifically Tuple.name or similar), which is wrong.

Backtracing to the culprit, we find that it is getting overwritten in the += of the reverse pass. Specifically, we have a for loop of apply-generic calls (for the iterate) and also a type-unstable "get nth index of". These result in an allocation of shadow pointers from their results for use in the reverse pass. We should have a cache of Tuple{Ref{Float64},Ref{Float64}}: the tuple of size 2 comes from the batch width, and the Ref is what the shadow from the augmented primal of the unstable "get nth index of" should return.

In the reverse pass, we should have two iterations, just like the forward pass has two iterations. The problem is on the last iteration of the reverse pass [aka the theoretical index 0]. The cache does not actually contain two Tuple{Ref{Float64},Ref{Float64}}'s (one per loop iteration). Instead, its first element is (1=Tuple.name, 2=Tuple.name) and its second is (1=RefValue{1.0}, 2=RefValue{1.0}) [random float values added, I don't recall what they were]. We had earlier gone down a rabbit hole looking at the assembly of the indexing, since LLVM appears to have split the induction variable into two counters: one that is negative and increments up, and another that counts down. One is used for indexing, the other for the reverse-pass loop exit. This originally looked suspicious because, on the last reverse-pass iteration (aka i == 0), the indexing counter seemed to index at an offset of -8, possibly implying the loop was doing an incorrect iteration. However, this negation was actually expected: it was just LLVM strength reduction hoisting part of the indexing computation out of the loop.

So the fundamental issue here is why the 0th element of the cache is busted. So far, no clue.

It is possible this could be simplified a bit from the Julia end by turning it into explicit iterate calls and then doing some more precise type-stability work -- but maybe not.

For the obvious reason that the loop size is unknown, we use the exponential allocation strategy, which makes the generated IR more obnoxious to look at. It may even be a bug on the Julia side of the exponential copy? Just a guess though, as we ended up very, very stuck.
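To make the "exponential allocation" concrete, here is a minimal hypothetical sketch of a doubling tape cache for a loop of unknown trip count. Enzyme emits the real thing in LLVM IR, not Julia, and the names here (TapeCache, push_entry!) are illustrative only:

```julia
# Hypothetical sketch of the doubling ("exponential") tape cache used when
# the loop trip count is unknown; not Enzyme's actual implementation.
mutable struct TapeCache{T}
    buf::Vector{T}
    len::Int
end
TapeCache{T}() where {T} = TapeCache{T}(Vector{T}(undef, 1), 0)

function push_entry!(c::TapeCache{T}, v::T) where {T}
    if c.len == length(c.buf)
        # Capacity exhausted: double the buffer and copy the old entries over.
        newbuf = Vector{T}(undef, 2 * length(c.buf))
        copyto!(newbuf, 1, c.buf, 1, c.len)
        c.buf = newbuf
    end
    c.len += 1
    c.buf[c.len] = v
    nothing
end

# Forward pass: one shadow tuple per iteration (batch width 2, so one
# Tuple{Ref{Float64},Ref{Float64}} per loop iteration, as described above).
cache = TapeCache{Tuple{Base.RefValue{Float64},Base.RefValue{Float64}}}()
for _ in 1:2
    push_entry!(cache, (Ref(0.0), Ref(0.0)))
end

# Reverse pass: walk the cache back from the last entry to the first.
for i in cache.len:-1:1
    _ = cache.buf[i]
end
```

The bug described above would correspond to `cache.buf[1]` holding garbage (Tuple.name pointers) instead of valid shadow Refs.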

@vchuravy (Member) commented Jun 5, 2024

We crash due to data corruption:

$2 = (jl_datatype_t *) 0x7ac98c0950d0
(rr) p jdt->name
$3 = (jl_typename_t *) 0x7ac9a1df6ae0 <jl_system_image_data+72308000>
(rr) p jdt->name->name
$4 = (jl_sym_t *) 0x4030000000000000
(rr) p jdt->name->name

Reverse executing

Thread 1 hit Hardware watchpoint 1: *(jl_sym_t **) 0x7ac9a1df6ae0

Old value = (jl_sym_t *) 0x4030000000000000
New value = (jl_sym_t *) 0x4010000000000000
0x00007ac9b308e9a5 in * () at float.jl:411
warning: 411	float.jl: No such file or directory
(rr) bt
#0  0x00007ac9b308e9a5 in * () at float.jl:411
#1  julia_make_byref_1412 (out=..., x=<optimized out>)
    at /home/vchuravy/src/Enzyme/rev_apply_iterate.jl:14
#2  0x00007ac9b308e9a5 in diffe2julia_make_byref_1412wrap ()
#3  0x00007ac9b30990d8 in macro expansion ()
    at /home/vchuravy/src/Enzyme/src/compiler.jl:5916
#4  enzyme_call () at /home/vchuravy/src/Enzyme/src/compiler.jl:5566

Second overwrite:

Thread 1 hit Hardware watchpoint 1: *(jl_sym_t **) 0x7ac9a1df6ae0

Old value = (jl_sym_t *) 0x4010000000000000
New value = (jl_sym_t *) 0x7ac9ab47c1d8
0x00007ac9b308e997 in * () at float.jl:411
411	in float.jl

Old value = (jl_sym_t *) 0x4010000000000000
New value = (jl_sym_t *) 0x7ac9ab47c1d8
0x00007ac9b308e997 in * () at float.jl:411
411	in float.jl
(rr) bt 4
#0  0x00007ac9b308e997 in * () at float.jl:411
#1  julia_make_byref_1412 (out=..., x=<optimized out>)
    at /home/vchuravy/src/Enzyme/rev_apply_iterate.jl:14
#2  0x00007ac9b308e997 in diffe2julia_make_byref_1412wrap ()
#3  0x00007ac9b30990d8 in macro expansion ()
    at /home/vchuravy/src/Enzyme/src/compiler.jl:5916
(More stack frames follow...)

Since this memory is allocated in the system image (it's the name of Tuple) it smells more like a calling mistake.

(rr) p jl_(jt)
Tuple{Base.RefValue{DataType}}
$1 = void

@vchuravy (Member) commented Jun 5, 2024

> Instead it contains as first element (1=Tuple.name, 2=Tuple.name)

So how did these values end up in the cache?

@vchuravy (Member) commented Jun 5, 2024

So I find:

#12 0x00007ac9b444a212 in ijl_apply_generic (
    F=0x7ac98ec33350 <jl_system_image_data+1456912>, args=0x7fffbe9ae1d8,
    nargs=5) at /home/vchuravy/src/julia-1.10/src/gf.c:3077
3077	    return _jl_invoke(F, args, nargs, mfunc, world);
(rr) p jl_(F)
Enzyme.Compiler.idx_jl_getfield_rev
$7 = void
(rr) p jl_(args[0])
Base.RefValue{DataType}(x=Tuple{Float64, Int64})
$8 = void
(rr) p jl_(args[1])
(1=Any, 2=Any)
$9 = void
(rr) p jl_(args[2])
Base.Val{1}
$10 = void
(rr) p jl_(args[3])
Base.Val{false}()
$11 = void
(rr) p jl_(args[4])
Base.RefValue{DataType}(x=Tuple{Float64, Int64})
$12 = void

A bit fishy. The code got changed here, but args[0] should be a Val.

@wsmoses (Member, Author) commented Jun 5, 2024

@vchuravy I don't think so?

The calling convention of that is

```julia
function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dptrs...) where {T, symname, isconst}
```

where dptr (and dptrs for extra batches) are the shadows, not Vals?
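To illustrate the argument roles for batch width 2, here is a hypothetical mock with the same signature shape. The name `idx_jl_getfield_rev_sketch` and its body are ours: it only demonstrates which positions carry shadows versus the tape, not Enzyme's actual accumulation logic:

```julia
# Mock mirroring the signature quoted above (NOT Enzyme's implementation):
# dptr is the primary shadow, dret is the tape/return shadow, and dptrs...
# carries the extra batch shadows.
function idx_jl_getfield_rev_sketch(dptr, dret, ::Type{Val{symname}}, ::Val{isconst},
                                    dptrs...) where {symname, isconst}
    # Reverse of getfield: accumulate each batch member's return shadow
    # into the corresponding input shadow.
    shadows = (dptr, dptrs...)
    for (s, d) in zip(shadows, dret)
        s[] += d[]
    end
    nothing
end

# Batch width 2: two input-shadow Refs, and a tape holding one adjoint
# per batch member (matching dout = 1.0 and dout2 = 3.0 above).
d1, d2 = Ref(0.0), Ref(0.0)
tape = (Ref(1.0), Ref(3.0))
idx_jl_getfield_rev_sketch(d1, tape, Val{1}, Val(false), d2)
```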

@wsmoses (Member, Author) commented Jun 5, 2024

```julia
    if !is_constant_value(gutils, ops[1])
        inp = invert_pointer(gutils, ops[1], B)
        inp = lookup_value(gutils, inp, B)
        if width == 1
            inps = [inp]
        else
            inps = LLVM.Value[]
            for w in 1:width
                push!(inps, extract_value!(B, inp, w-1))
            end
        end
    else
        inp = new_from_original(gutils, ops[1])
        inp = lookup_value(gutils, inp, B)
        inps = [inp]
    end

    vals = LLVM.Value[]
    push!(vals, inps[1])

    push!(vals, tape)

    sym = new_from_original(gutils, ops[2])
    sym = lookup_value(gutils, sym, B)
    sym = (sizeof(Int) == sizeof(Int64) ? emit_box_int64! : emit_box_int32!)(B, sym)
    sym = emit_apply_type!(B, Base.Val, [sym])
    push!(vals, sym)

    push!(vals, unsafe_to_llvm(Val(is_constant_value(gutils, orig))))

    for v in inps[2:end]
        push!(vals, v)
    end

    pushfirst!(vals, unsafe_to_llvm(idx_jl_getfield_rev))
```

So the argument order is: the function, the original shadow, the tape [aka dret], Type{Val{sym}}, then the batched other shadows [from 2 ... n].

The shadow tape does look fishy though.

@wsmoses (Member, Author) commented Jun 5, 2024

@vchuravy okay, the caching mechanism is fine; the function is actually returning garbage.

"Result of call to idx_jl_getfield_aug"
(1=Tuple.name, 2=Tuple.name)
Tuple.name
Tuple.name
"Val(AnyArray)="

@wsmoses (Member, Author) commented Jun 5, 2024

The input to idx_jl_getfield_aug is RefValue{DataType}(Tuple{Float64, Int64})

@codecov-commenter commented Jun 7, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 96.35%. Comparing base (ad7694e) to head (9499ee4).


Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1485       +/-   ##
===========================================
+ Coverage   71.15%   96.35%   +25.19%     
===========================================
  Files          30        9       -21     
  Lines       11224      411    -10813     
===========================================
- Hits         7986      396     -7590     
+ Misses       3238       15     -3223     


@wsmoses closed this Jun 8, 2024
@wsmoses reopened this Jun 8, 2024
@wsmoses closed this Jun 8, 2024
@wsmoses reopened this Jun 8, 2024
@wsmoses merged commit 86da3cd into main Jun 8, 2024 (47 of 145 checks passed)
@wsmoses deleted the reviterate branch June 8, 2024 21:39