Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Eager mode #491

Merged
merged 49 commits into from
Mar 26, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
2d1eda8
Start work on eager mode
malmaud Jan 20, 2019
b23ff94
Trying to get 'execute' working.
malmaud Jan 22, 2019
4196715
Test working
malmaud Jan 30, 2019
e8e0dc1
Attribute setters
malmaud Jan 30, 2019
104097a
Variable output
malmaud Jan 30, 2019
de7bafc
Tape AD
malmaud Feb 20, 2019
19161f7
fix tape
malmaud Feb 20, 2019
c16ff7a
Start generating eager ops
malmaud Feb 20, 2019
c8ba0e3
Test working again
malmaud Feb 20, 2019
b55ecd6
Improvements to importing
malmaud Feb 20, 2019
8cadecf
Switch to dispatch
malmaud Feb 20, 2019
43393b8
casting
malmaud Feb 20, 2019
fc3dfb9
import more ops
malmaud Feb 21, 2019
a152ca4
Started gradients
malmaud Feb 22, 2019
21bddb0
Misc.
malmaud Feb 22, 2019
7d1d071
Diffeq functionality
malmaud Feb 25, 2019
50d0659
Keras demo
malmaud Feb 25, 2019
42ce802
Better grads
malmaud Feb 25, 2019
f95012d
Improved Keras
malmaud Feb 25, 2019
06b5bcc
Rename Relu layer
malmaud Feb 25, 2019
d9f7f5e
Eager summaries
malmaud Feb 25, 2019
53d029d
Summary tweaks
malmaud Feb 25, 2019
760601f
tweaks
malmaud Feb 25, 2019
5780f55
scalar summary macro
malmaud Feb 26, 2019
edec9a4
Switch to stdlib crc
malmaud Feb 26, 2019
b42dc58
Switch to context system.
malmaud Feb 26, 2019
c42d142
Move tape to context system
malmaud Feb 27, 2019
323b6a3
Clear tape
malmaud Feb 27, 2019
cc1dd39
Add CRC32 dep
malmaud Mar 4, 2019
a555ff8
Disable eager by default
malmaud Mar 4, 2019
1c12f64
PyCall adjustments etc.
malmaud Mar 14, 2019
3051907
Overhaul context system
malmaud Mar 14, 2019
f7cbd38
Switch to dispatch for gradients
malmaud Mar 14, 2019
e05997b
Misc improvements
malmaud Mar 15, 2019
f6ef4c7
Apply suggestions from code review
oxinabox Mar 15, 2019
fb36d47
Rename TensorHandle->EagerTensor
malmaud Mar 15, 2019
3f6e44e
Add 'reverse' with dims kw argument
malmaud Mar 15, 2019
edac1b6
call_args style
malmaud Mar 15, 2019
81dea7d
Better fields in Sequential
malmaud Mar 15, 2019
d4fae80
eliminate some stubs
malmaud Mar 15, 2019
492d3bf
remove accidental scratch
malmaud Mar 15, 2019
5fda547
remove neural ode example until we get it robustly functional
malmaud Mar 15, 2019
8cd906f
Export symbols and change tape.
malmaud Mar 15, 2019
8c39cf5
Bump tf version
malmaud Mar 15, 2019
71e668e
Fix some tests
malmaud Mar 15, 2019
ef6b3ab
Move @op
malmaud Mar 16, 2019
d4a8d5f
Tests working
malmaud Mar 16, 2019
f2037bd
Downgrade conda version
malmaud Mar 16, 2019
0492959
Change isfile to ispath in summary_writer
malmaud Mar 18, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 4 additions & 14 deletions src/TensorFlow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,23 +141,13 @@ function deallocator(data, len, arg)

end

struct Context
attrs::Dict
end

Context() = Context(Dict())

struct ContextStack
contexts::Vector{Context}
end

ContextStack() = ContextStack(Context[])

const global_context = ContextStack()
include("context.jl")

function __init__()
    # The @cfunction pointer must be created at runtime (it is not valid across
    # precompilation), so the deallocator callback is registered here.
    c_deallocator[] = @cfunction(deallocator, Cvoid, (Ptr{Cvoid}, Csize_t, Ptr{Cvoid}))
    # Seed the global context stack with the default contexts (a non-eager
    # ExecutionMode) so lookups like `context_value` always find a value.
    # Note: push each context individually — pushing the vector returned by
    # `default_context()` itself would put a Vector, not a Context, on the stack.
    for context in default_context()
        push!(global_context, context)
    end
end

function load_python_process(;force_reload=false)
Expand Down
44 changes: 44 additions & 0 deletions src/context.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Root of the context-type hierarchy used for dynamically scoped execution
# state (eager mode, gradient tapes, default summary writers, ...).
abstract type Context
end

# A LIFO stack of `Context`s; entries pushed later shadow earlier entries of
# the same concrete type (see `Base.getindex`).
struct ContextStack
    contexts::Vector{Context}
end

ContextStack() = ContextStack(Context[])

# Stack operations simply delegate to the underlying vector.
Base.push!(stack::ContextStack, context::Context) = push!(stack.contexts, context)

Base.pop!(stack::ContextStack) = pop!(stack.contexts)

# The context(s) every fresh process starts with: graph-mode (non-eager)
# execution.
default_context() = [ExecutionMode(eager=false)]

# Look up the innermost context of `context_type` on the global stack;
# returns `nothing` when no such context is active.
context_value(context_type) = global_context[context_type]

function Base.getindex(c::ContextStack, context_type)
    # Return the most recently pushed context of type `context_type`, or
    # `nothing` if none is on the stack. Scanning from the top of the stack
    # lets us return at the first hit instead of walking the whole vector to
    # remember the last match.
    for context in Iterators.reverse(c.contexts)
        isa(context, context_type) && return context
    end
    return nothing
end

function with_context(block, ctx)
    # Run `block()` with `ctx` pushed onto the global context stack.
    # The pop is placed in `finally` so that an exception thrown by `block`
    # cannot leave a stale context on the stack (the original leaked the
    # context on error).
    push!(global_context, ctx)
    try
        return block()
    finally
        pop!(global_context)
    end
end


# Process-wide stack of active contexts. Seeded by `__init__` with
# `default_context()`; scoped mutation goes through `with_context`.
const global_context = ContextStack()
56 changes: 14 additions & 42 deletions src/eager.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
mutable struct EagerContext
# Handle to the C `TFE_Context` used for eager execution; subtypes `Context`
# so it can live on the global context stack (see `get_eager_context`).
# Mutable, presumably so a finalizer can release the raw pointer — TODO
# confirm against the constructor (not visible in this chunk).
mutable struct EagerContext <: Context
    ptr::Ptr{Cvoid}
end

Expand Down Expand Up @@ -121,9 +121,10 @@ end

function EagerOp(op_name)
if get_eager_context() === nothing
ctx = Context()
ctx.attrs["eager_context"] = EagerContext()
push!(global_context, ctx)
# ctx = Context()
# ctx.attrs["eager_context"] = EagerContext()
# push!(global_context, ctx)
push!(global_context, EagerContext())
end
ctx = get_eager_context()
status = Status()
Expand Down Expand Up @@ -245,7 +246,6 @@ function clear_caches(ctx::EagerContext)
@tfcall(:TFE_ContextClearCaches, Cvoid, (Ptr{Cvoid},), ctx)
end


function num_dims(h::TensorHandle)
status = Status()
res = @tfcall(:TFE_TensorHandleNumDims, Cint, (Ptr{Cvoid}, Ptr{Cvoid}), h, status)
Expand Down Expand Up @@ -335,52 +335,24 @@ function inplace_sub(x, y)
Ops.inplace_sub(x, i, y)
end

function Base.push!(stack::ContextStack, context::Context)
push!(stack.contexts, context)
end

function Base.pop!(stack::ContextStack)
pop!(stack.contexts)
# Context flag selecting eager vs. graph execution (queried by
# `in_eager_mode`).
struct ExecutionMode <: Context
    eager::Bool
end

function default_context()
context = Context()
context.attrs["eager"] = false
return context
end
# Keyword convenience constructor; a bare `ExecutionMode()` enables eager mode.
ExecutionMode(;eager=true) = ExecutionMode(eager)

function enable_eager_execution()
    # Enable eager execution for all subsequently created ops by pushing an
    # eager `ExecutionMode` onto the global context stack. (The superseded
    # attrs-dict implementation that was left behind as commented-out code has
    # been removed.)
    push!(global_context, ExecutionMode(eager=true))
    return nothing
end

function Base.getindex(c::ContextStack, name)
value = nothing
for context in c.contexts
if name in keys(context.attrs)
value = context.attrs[name]
end
end
return value
end

function context_value(name)
return global_context[name]
end

function in_eager_mode()
    # `.eager` of the innermost ExecutionMode on the context stack. Assumes an
    # ExecutionMode is always present (seeded by `__init__` via
    # `default_context`); otherwise `context_value` returns `nothing` and this
    # errors — TODO confirm that invariant holds for all entry points.
    return context_value(ExecutionMode).eager
end

function get_eager_context()
    # The active `EagerContext`, or `nothing` if eager execution has never been
    # initialized (callers such as `EagerOp` create one lazily in that case).
    return context_value(EagerContext)
end
malmaud marked this conversation as resolved.
Show resolved Hide resolved
13 changes: 4 additions & 9 deletions src/summary_writer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import ..TensorFlow
const tf = TensorFlow
import ..TensorFlow: tensorflow, Graph, get_def_graph, @py_proc

struct FileWriter
# A summary (event-file) writer. Subtypes `tf.Context` so a writer can itself
# be pushed onto the global context stack as the default summary destination
# (see `set_default` / `get_default_file_writer`).
struct FileWriter <: tf.Context
    # NOTE(review): untyped field — presumably a handle into the python
    # process; confirm the concrete type and annotate.
    file_handle
    logdir::String
end
Expand Down Expand Up @@ -100,19 +100,15 @@ function Base.write(writer::FileWriter, graph::Graph)
end

function set_default(writer::FileWriter)
    # Install `writer` as the default summary writer by pushing it directly
    # onto the global context stack (FileWriter <: tf.Context). Remains active
    # until popped.
    push!(tf.global_context, writer)
end

function with_default(writer::FileWriter, block)
    # Run `block` with `writer` as the default summary writer, restoring the
    # previous default afterwards. (Argument order is fixed by existing
    # callers; `block` cannot be moved first without breaking the interface.)
    tf.with_context(block, writer)
end

function get_default_file_writer()
    # Innermost `FileWriter` on the context stack, or `nothing` if no default
    # writer has been set.
    return tf.context_value(FileWriter)
end

function record_summary(summary_pb; step=0)
Expand All @@ -121,7 +117,6 @@ function record_summary(summary_pb; step=0)
write(writer, summary_pb, step)
end


function Base.close(writer::FileWriter)
close(writer.file_handle)
nothing
Expand Down
21 changes: 14 additions & 7 deletions src/tape.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,26 @@ end

Tape(;kwargs...) = Tape(Dict{TensorHandle, TapeNode}(), Dict(kwargs...))

# Context carrying the active gradient tape; `tape === nothing` marks a
# no-gradient region (see `with_no_grad`).
struct TapeContext <: Context
    tape::Union{Tape, Nothing}
end

function set_tape(new_tape=nothing)
    # Install `new_tape` (a fresh `Tape` when not given) as the active gradient
    # tape by pushing a `TapeContext`, and return the tape so callers can keep
    # a reference for later gradient queries.
    if new_tape === nothing
        new_tape = Tape()
    end
    push!(global_context, TapeContext(new_tape))
    return new_tape
end

function get_tape()
    # The tape of the innermost `TapeContext`, or `nothing` when no
    # TapeContext is installed (or taping is masked out by `with_no_grad`,
    # which pushes `TapeContext(nothing)`). Written compactly per review
    # feedback requesting a shorter form.
    tape_context = context_value(TapeContext)
    return tape_context === nothing ? nothing : tape_context.tape
end

function add_node(t, node)
tape = get_tape()
Expand Down Expand Up @@ -128,9 +137,7 @@ end)
end)

function with_no_grad(f)
    # Evaluate `f()` with gradient recording disabled by shadowing any active
    # tape with `TapeContext(nothing)` for the duration of the call.
    return with_context(f, TapeContext(nothing))
end

Expand Down