Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Eager mode #491

Merged
merged 49 commits into from
Mar 26, 2019
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
2d1eda8
Start work on eager mode
malmaud Jan 20, 2019
b23ff94
Trying to get 'execute' working.
malmaud Jan 22, 2019
4196715
Test working
malmaud Jan 30, 2019
e8e0dc1
Attribute setters
malmaud Jan 30, 2019
104097a
Variable output
malmaud Jan 30, 2019
de7bafc
Tape AD
malmaud Feb 20, 2019
19161f7
fix tape
malmaud Feb 20, 2019
c16ff7a
Start generating eager ops
malmaud Feb 20, 2019
c8ba0e3
Test working again
malmaud Feb 20, 2019
b55ecd6
Improvements to importing
malmaud Feb 20, 2019
8cadecf
Switch to dispatch
malmaud Feb 20, 2019
43393b8
casting
malmaud Feb 20, 2019
fc3dfb9
import more ops
malmaud Feb 21, 2019
a152ca4
Started gradients
malmaud Feb 22, 2019
21bddb0
Misc.
malmaud Feb 22, 2019
7d1d071
Diffeq functionality
malmaud Feb 25, 2019
50d0659
Keras demo
malmaud Feb 25, 2019
42ce802
Better grads
malmaud Feb 25, 2019
f95012d
Improved Keras
malmaud Feb 25, 2019
06b5bcc
Rename Relu layer
malmaud Feb 25, 2019
d9f7f5e
Eager summaries
malmaud Feb 25, 2019
53d029d
Summary tweaks
malmaud Feb 25, 2019
760601f
tweaks
malmaud Feb 25, 2019
5780f55
scalar summary macro
malmaud Feb 26, 2019
edec9a4
Switch to stdlib crc
malmaud Feb 26, 2019
b42dc58
Switch to context system.
malmaud Feb 26, 2019
c42d142
Move tape to context system
malmaud Feb 27, 2019
323b6a3
Clear tape
malmaud Feb 27, 2019
cc1dd39
Add CRC32 dep
malmaud Mar 4, 2019
a555ff8
Disable eager by default
malmaud Mar 4, 2019
1c12f64
PyCall adjustments etc.
malmaud Mar 14, 2019
3051907
Overhault context system
malmaud Mar 14, 2019
f7cbd38
Switch to dispatch for gradients
malmaud Mar 14, 2019
e05997b
Misc improvements
malmaud Mar 15, 2019
f6ef4c7
Apply suggestions from code review
oxinabox Mar 15, 2019
fb36d47
Rename TensorHandle->EagerTensor
malmaud Mar 15, 2019
3f6e44e
Add 'reverse' with dims kw argument
malmaud Mar 15, 2019
edac1b6
call_args style
malmaud Mar 15, 2019
81dea7d
Better fields in Sequential
malmaud Mar 15, 2019
d4fae80
eliminate some stubs
malmaud Mar 15, 2019
492d3bf
remove accidental scratch
malmaud Mar 15, 2019
5fda547
remove neural ode example until we get it robustly functional
malmaud Mar 15, 2019
8cd906f
Export symbols and change tape.
malmaud Mar 15, 2019
8c39cf5
Bump tf version
malmaud Mar 15, 2019
71e668e
Fix some tests
malmaud Mar 15, 2019
ef6b3ab
Move @op
malmaud Mar 16, 2019
d4a8d5f
Tests working
malmaud Mar 16, 2019
f2037bd
Downgrade conda version
malmaud Mar 16, 2019
0492959
Change isfile to ispath in summary_writer
malmaud Mar 18, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ version = "0.12.0"

[deps]
AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f"
CRC32c = "8bf52ea8-c179-5cab-976a-9e18b702a9bc"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
Expand Down
8 changes: 8 additions & 0 deletions examples/diffeq.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
using DifferentialEquations

f(u,p,t)=1.01 .* u

u0=constant(0.5)
tspan=(0.0,1.0)
prob=ODEProblem(f, u0, tspan)
s=solve(prob)
12 changes: 12 additions & 0 deletions examples/keras.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
using TensorFlow
tf=TensorFlow
m = tf.Sequential()

tf.add(m, tf.Dense(3,10))
tf.add(m, tf.Relu())
tf.add(m, tf.Dense(10, 3))

x=constant(randn(5,3))
y=3x+5
tf.compile(m, optimizer=tf.SGD(lr=1e-4), loss=tf.mse)
tf.fit(m, x, y, n_epochs=200)
8 changes: 8 additions & 0 deletions examples/neural_ode.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
using TensorFlow
using DifferentialEquations

model = tf.Sequential([tf.Dense(2, 1)])
f(u, p, t) = model(u)
problem = ODEProblem(f, u0=[0.5, 0.5], tspan=(0.0, 1.0))
tf.compile(model, optimizer=tf.Adam(), loss=tf.diffeq_loss(problem, t=[0.0, 0.5, 1.0]))
tf.fit(m, [1.0, 2.0, 5.0], n_epochs=100)
18 changes: 18 additions & 0 deletions src/TensorFlow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,23 @@ function deallocator(data, len, arg)

end

struct Context
attrs::Dict
malmaud marked this conversation as resolved.
Show resolved Hide resolved
end

Context() = Context(Dict())

struct ContextStack
contexts::Vector{Context}
end

ContextStack() = ContextStack(Context[])

global_context = ContextStack()
malmaud marked this conversation as resolved.
Show resolved Hide resolved

function __init__()
c_deallocator[] = @cfunction(deallocator, Cvoid, (Ptr{Cvoid}, Csize_t, Ptr{Cvoid}))
push!(global_context, default_context())
end

function load_python_process(;force_reload=false)
Expand Down Expand Up @@ -198,6 +213,7 @@ include("meta.jl")
include("constants.jl")
include("tensorflow_protos.jl")
include("core.jl")
include("eager.jl")
include("run.jl")
include("version.jl")
include("ops.jl")
Expand All @@ -211,5 +227,7 @@ include("summary.jl")
include("deprecated.jl")
include("show.jl")
include("generate_ops.jl")
include("tape.jl")
include("keras.jl")

end
13 changes: 11 additions & 2 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,10 @@ mutable struct DeviceList
end
this
end

function DeviceList(ptr, count)
malmaud marked this conversation as resolved.
Show resolved Hide resolved
new(ptr, count)
end
end

struct DeviceInfo
Expand Down Expand Up @@ -663,6 +667,8 @@ RawTensor(data::AbstractArray) = RawTensor(collect(data))

RawTensor(t::RawTensor) = t

Base.unsafe_convert(::Type{Ptr{Cvoid}}, t::RawTensor) = t.ptr

function varint_encode(b::IO, n::Integer)
while n ≥ 2^7
write(b, UInt8(0b10000000 | (n & 0b1111111)))
Expand Down Expand Up @@ -803,7 +809,7 @@ function Base.sizeof(t::RawTensor)
@tfcall(:TF_TensorByteSize, Csize_t, (Ptr{Cvoid},), t.ptr) |> Int
end

function set_device(node_desc, device::String)
function set_device(node_desc, device)
@tfcall(:TF_SetDevice, Cvoid,
(Ptr{Cvoid}, Cstring),
node_desc.ptr, device)
Expand Down Expand Up @@ -1168,7 +1174,7 @@ function load_proto(value::tensorflow.AttrValue)
load_proto(value.list)
elseif has_field(value, :_type)
type_ = value._type
proto_type_map[type_]
get(proto_type_map, type_, Float32) # wrong
malmaud marked this conversation as resolved.
Show resolved Hide resolved
end
end

Expand Down Expand Up @@ -1227,6 +1233,8 @@ function Tensor(op::Operation, value_index::Int)
Tensor{get_output_type(base_tensor)}(op, value_index)
end

# Tensor constructors
malmaud marked this conversation as resolved.
Show resolved Hide resolved

Tensor(op::Operation) = Tensor(op, 1)

malmaud marked this conversation as resolved.
Show resolved Hide resolved
Tensor(value) = convert(Tensor, value)
Expand All @@ -1242,6 +1250,7 @@ Base.convert(::Type{Tensor{Any}}, value::Tensor{R}) where {R} = value

Base.convert(::Type{Tensor{T}}, value) where {T} = convert(Tensor{T}, constant(value))


function operation_output_type(port::Port)
@tfcall(:TF_OperationOutputType, TF_DataType, (Port,), port)
end
Expand Down
Loading