
WIP: Eager mode #491

Merged: 49 commits, Mar 26, 2019
Changes shown below are from 5 commits.

Commits
2d1eda8
Start work on eager mode
malmaud Jan 20, 2019
b23ff94
Trying to get 'execute' working.
malmaud Jan 22, 2019
4196715
Test working
malmaud Jan 30, 2019
e8e0dc1
Attribute setters
malmaud Jan 30, 2019
104097a
Variable output
malmaud Jan 30, 2019
de7bafc
Tape AD
malmaud Feb 20, 2019
19161f7
fix tape
malmaud Feb 20, 2019
c16ff7a
Start generating eager ops
malmaud Feb 20, 2019
c8ba0e3
Test working again
malmaud Feb 20, 2019
b55ecd6
Improvements to importing
malmaud Feb 20, 2019
8cadecf
Switch to dispatch
malmaud Feb 20, 2019
43393b8
casting
malmaud Feb 20, 2019
fc3dfb9
import more ops
malmaud Feb 21, 2019
a152ca4
Started gradients
malmaud Feb 22, 2019
21bddb0
Misc.
malmaud Feb 22, 2019
7d1d071
Diffeq functionality
malmaud Feb 25, 2019
50d0659
Keras demo
malmaud Feb 25, 2019
42ce802
Better grads
malmaud Feb 25, 2019
f95012d
Improved Keras
malmaud Feb 25, 2019
06b5bcc
Rename Relu layer
malmaud Feb 25, 2019
d9f7f5e
Eager summaries
malmaud Feb 25, 2019
53d029d
Summary tweaks
malmaud Feb 25, 2019
760601f
tweaks
malmaud Feb 25, 2019
5780f55
scalar summary macro
malmaud Feb 26, 2019
edec9a4
Switch to stdlib crc
malmaud Feb 26, 2019
b42dc58
Switch to context system.
malmaud Feb 26, 2019
c42d142
Move tape to context system
malmaud Feb 27, 2019
323b6a3
Clear tape
malmaud Feb 27, 2019
cc1dd39
Add CRC32 dep
malmaud Mar 4, 2019
a555ff8
Disable eager by default
malmaud Mar 4, 2019
1c12f64
PyCall adjustments etc.
malmaud Mar 14, 2019
3051907
Overhaul context system
malmaud Mar 14, 2019
f7cbd38
Switch to dispatch for gradients
malmaud Mar 14, 2019
e05997b
Misc improvements
malmaud Mar 15, 2019
f6ef4c7
Apply suggestions from code review
oxinabox Mar 15, 2019
fb36d47
Rename TensorHandle->EagerTensor
malmaud Mar 15, 2019
3f6e44e
Add 'reverse' with dims kw argument
malmaud Mar 15, 2019
edac1b6
call_args style
malmaud Mar 15, 2019
81dea7d
Better fields in Sequential
malmaud Mar 15, 2019
d4fae80
eliminate some stubs
malmaud Mar 15, 2019
492d3bf
remove accidental scratch
malmaud Mar 15, 2019
5fda547
remove neural ode example until we get it robustly functional
malmaud Mar 15, 2019
8cd906f
Export symbols and change tape.
malmaud Mar 15, 2019
8c39cf5
Bump tf version
malmaud Mar 15, 2019
71e668e
Fix some tests
malmaud Mar 15, 2019
ef6b3ab
Move @op
malmaud Mar 16, 2019
d4a8d5f
Tests working
malmaud Mar 16, 2019
f2037bd
Downgrade conda version
malmaud Mar 16, 2019
0492959
Change isfile to ispath in summary_writer
malmaud Mar 18, 2019
1 change: 1 addition & 0 deletions src/TensorFlow.jl
@@ -198,6 +198,7 @@ include("meta.jl")
include("constants.jl")
include("tensorflow_protos.jl")
include("core.jl")
include("eager.jl")
include("run.jl")
include("version.jl")
include("ops.jl")
6 changes: 6 additions & 0 deletions src/core.jl
@@ -521,6 +521,10 @@ mutable struct DeviceList
end
this
end

function DeviceList(ptr, count)
new(ptr, count)
end
end

struct DeviceInfo
@@ -663,6 +667,8 @@ RawTensor(data::AbstractArray) = RawTensor(collect(data))

RawTensor(t::RawTensor) = t

Base.unsafe_convert(::Type{Ptr{Cvoid}}, t::RawTensor) = t.ptr

function varint_encode(b::IO, n::Integer)
while n ≥ 2^7
write(b, UInt8(0b10000000 | (n & 0b1111111)))
188 changes: 188 additions & 0 deletions src/eager.jl
@@ -0,0 +1,188 @@
mutable struct EagerContext
ptr::Ptr{Cvoid}

function EagerContext()
options = @tfcall(:TFE_NewContextOptions, Ptr{Cvoid}, ())
@tfcall(:TFE_ContextOptionsSetAsync, Cvoid, (Ptr{Cvoid}, Cuchar), options, 0)
status = Status()
context = @tfcall(:TFE_NewContext, Ptr{Cvoid}, (Ptr{Cvoid}, Ptr{Cvoid}), options, status)
check_status(status)
this = new(context)
finalizer(this) do self
@tfcall(:TFE_DeleteContext, Cvoid, (Ptr{Cvoid},), self.ptr)
end
@tfcall(:TFE_DeleteContextOptions, Cvoid, (Ptr{Cvoid},), options)
return this
end
end

Base.unsafe_convert(::Type{Ptr{Cvoid}}, c::EagerContext) = c.ptr
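This `unsafe_convert` method is what lets an `EagerContext` be passed directly wherever `@tfcall` declares a `Ptr{Cvoid}` argument. A minimal standalone sketch of the pattern in plain Julia (`Handle` and the C function name are hypothetical stand-ins):

```julia
# Hypothetical wrapper around a C-owned resource.
mutable struct Handle
    ptr::Ptr{Cvoid}
end

# With this method defined, ccall converts a Handle argument to its raw
# pointer automatically whenever the declared argument type is Ptr{Cvoid}.
Base.unsafe_convert(::Type{Ptr{Cvoid}}, h::Handle) = h.ptr

# Usage (hypothetical C function):
# ccall(:some_c_function, Cvoid, (Ptr{Cvoid},), h)
```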

function DeviceList(ctx::EagerContext)
status = Status()
ptr = @tfcall(:TFE_ContextListDevices, Ptr{Cvoid}, (Ptr{Cvoid}, Ptr{Cvoid}), ctx, status)
check_status(status)
count = @tfcall(:TF_DeviceListCount, Cint, (Ptr{Cvoid},), ptr)
this = DeviceList(ptr, count)
Collaborator: Can't call `new` outside of an inner constructor?

return this
end

mutable struct TensorHandle
Collaborator: This doesn't need to be mutable, does it?

Owner Author: I think it does, so that it can have a finalizer.

Collaborator: You might be right there. Though IIRC you can hook the finalizer directly to the Ptr.

ptr::Ptr{Cvoid}

function TensorHandle(tensor)
status = Status()
ptr = @tfcall(:TFE_NewTensorHandle, Ptr{Cvoid}, (Ptr{Cvoid}, Ptr{Cvoid}), tensor.ptr, status)
check_status(status)
this = new(ptr)
finalizer(this) do self
@tfcall(:TFE_DeleteTensorHandle, Cvoid, (Ptr{Cvoid},), self.ptr)
end
return this
end

function TensorHandle()
return new()
end
end
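Per the mutability discussion above: `finalizer` can only be attached to mutable values, which is why `TensorHandle` is declared `mutable struct`. A minimal sketch of the idiom, where the commented-out cleanup call is a hypothetical stand-in for `TFE_DeleteTensorHandle`:

```julia
mutable struct Resource
    ptr::Ptr{Cvoid}
    function Resource(ptr::Ptr{Cvoid})
        this = new(ptr)
        # Runs when `this` becomes unreachable; attaching a finalizer
        # requires the struct to be mutable.
        finalizer(this) do self
            # release(self.ptr)  # hypothetical C-side cleanup
        end
        return this
    end
end
```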

Base.unsafe_convert(::Type{Ptr{Cvoid}}, h::TensorHandle) = h.ptr

function async_wait(ctx::EagerContext)
status = Status()
@tfcall(:TFE_ContextAsyncWait, Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}), ctx, status)
check_status(status)
end

function device_name(h::TensorHandle)
status = Status()
c_name = @tfcall(:TFE_TensorHandleDeviceName, Cstring, (Ptr{Cvoid}, Ptr{Cvoid}), h, status)
check_status(status)
return unsafe_string(c_name)
end

function data_type(h::TensorHandle)
return @tfcall(:TFE_TensorHandleDataType, TF_DataType, (Ptr{Cvoid},), h) |> tf_to_jl_type
end

function resolve(h::TensorHandle)
status = Status()
ptr = @tfcall(:TFE_TensorHandleResolve, Ptr{Cvoid}, (Ptr{Cvoid}, Ptr{Cvoid}), h, status)
check_status(status)
tensor = RawTensor(ptr)
return tensor
end

mutable struct EagerOp
ptr::Ptr{Cvoid}
op_name::String
end

function EagerOp(ctx::EagerContext, op_name)
status = Status()
ptr = @tfcall(:TFE_NewOp, Ptr{Cvoid}, (Ptr{Cvoid}, Cstring, Ptr{Cvoid}), ctx, op_name, status)
check_status(status)
this = EagerOp(ptr, String(op_name))
finalizer(this) do self
@tfcall(:TFE_DeleteOp, Cvoid, (Ptr{Cvoid},), self)
end
return this
end

Base.unsafe_convert(::Type{Ptr{Cvoid}}, op::EagerOp) = op.ptr

function add_input(op::EagerOp, h::TensorHandle)
status = Status()
@tfcall(:TFE_OpAddInput, Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}), op, h, status)
check_status(status)
return
end

function execute(op::EagerOp)
op_desc = get_op_def(op.op_name)
n_outputs = length(op_desc.output_arg)
handles = [TensorHandle() for _ in 1:n_outputs]
ptrs = [Ptr{Cvoid}(0) for _ in 1:n_outputs]
num_ret = Cint(n_outputs)
status = Status()
@tfcall(:TFE_Execute, Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}, Ptr{Cvoid}), op, ptrs, Ref(num_ret), status)
check_status(status)
for i in 1:n_outputs
handles[i].ptr = ptrs[i]
end
return handles
end
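`execute` follows the usual C out-parameter convention: allocate a pointer buffer and a `Ref` count up front, let the C call fill them, then copy the pointers into the pre-built handles. A standalone sketch of that convention, with a plain Julia function standing in for `TFE_Execute`:

```julia
# Hypothetical stand-in for a C function that writes results through
# caller-provided buffers.
function fill_outputs!(retvals::Vector{Ptr{Cvoid}}, num_ret::Ref{Cint})
    for i in 1:num_ret[]
        retvals[i] = Ptr{Cvoid}(i)  # pretend these are real handles
    end
end

n_outputs = 3
ptrs = fill(Ptr{Cvoid}(0), n_outputs)
num_ret = Ref(Cint(n_outputs))
fill_outputs!(ptrs, num_ret)
# ptrs now holds the pointers produced by the callee
```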

function test_eager()
ctx = EagerContext()
h1 = TensorHandle(RawTensor([1,2]))
h2 = TensorHandle(RawTensor([3,4]))
op = EagerOp(ctx, "Add")
add_input(op, h1)
add_input(op, h2)
dtype = data_type(h1)
op["T"] = dtype
res = execute(op)
return resolve(res[1])
end

function setindex!(op::EagerOp, tensor::RawTensor, attr_name)
status = Status()
@tfcall(:TFE_OpSetAttrTensor, Cvoid, (Ptr{Cvoid}, Cstring, Ptr{Cvoid}, Ptr{Cvoid}), op, attr_name, tensor, status)
check_status(status)
end

function setindex!(op::EagerOp, dtype::DataType, attr_name)
@tfcall(:TFE_OpSetAttrType, Cvoid, (Ptr{Cvoid}, Cstring, TF_DataType), op, attr_name, dtype|>jl_to_df_type)
Collaborator suggested change:
    tf_dtype = jl_to_df_type(dtype)
    @tfcall(:TFE_OpSetAttrType, Cvoid, (Ptr{Cvoid}, Cstring, TF_DataType), op, attr_name, tf_dtype)

end

function setindex!(op::EagerOp, value::Integer, attr_name)
value = Int64(value)
@tfcall(:TFE_OpSetAttrInt, Cvoid, (Ptr{Cvoid}, Cstring, Int64), op, attr_name, value)
end

function setindex!(op::EagerOp, value::Bool, attr_name)
@tfcall(:TFE_OpSetAttrBool, Cvoid, (Ptr{Cvoid}, Cstring, Cuchar), op, attr_name, value)
end

function setindex!(op::EagerOp, value::AbstractFloat, attr_name)
value = Float32(value)
@tfcall(:TFE_OpSetAttrFloat, Cvoid, (Ptr{Cvoid}, Cstring, Cfloat), op, attr_name, value)
end

function setindex!(op::EagerOp, value::AbstractString, attr_name)
value = String(value)
@tfcall(:TFE_OpSetAttrString, Cvoid, (Ptr{Cvoid}, Cstring, Ptr{Cvoid}, Cint), op, attr_name, Vector{UInt8}(value), sizeof(value))
end

function setindex!(op::EagerOp, value::Vector, attr_name)
set_attr_list(op, attr_name, value)
end

function set_attr_list(op::EagerOp, attr_name, list::Vector{<:Integer})
list = Int64[Int64(x) for x in list]
Collaborator suggested change:
    list = Int64.(list)

@tfcall(:TFE_OpSetAttrIntList, Cvoid, (Ptr{Cvoid}, Cstring, Ptr{Int64}, Cint), op, attr_name, list, length(list))
end

function set_attr_list(op::EagerOp, attr_name, list::Vector{<:AbstractFloat})
list = Float32[Float32(x) for x in list]
Collaborator suggested change:
    list = Float32.(list)

@tfcall(:TFE_OpSetAttrFloatList, Cvoid, (Ptr{Cvoid}, Cstring, Ptr{Float32}, Cint), op, attr_name, list, length(list))
end

function set_attr_list(op::EagerOp, attr_name, list::Vector{<:DataType})
list = map(jl_to_df_type, list)
@tfcall(:TFE_OpSetAttrTypeList, Cvoid, (Ptr{Cvoid}, Cstring, Ptr{Cvoid}, Cint), op, attr_name, list, length(list))
end

function set_attr_shape_list(op::EagerOp, attr_name, list::Vector)
dims = Vector{Int64}[]
for shape in list
push!(dims, Int64[shape...])
Collaborator: Isn't this an error?

Owner Author: I think it's right? At least Keno's PR 9325475 touched this code to make it what it is now.

end
@tfcall(:TFE_OpSetAttrShapeList, Cvoid, (Ptr{Cvoid}, Cstring, Ptr{Ptr{Int64}}, Ptr{Cint}, Cint),
op,
attr_name,
dims,
Cint[length(x) for x in dims],
length(dims))
end
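For the shape-list question in the thread above: the splat turns each shape tuple into a `Vector{Int64}`, matching the `Ptr{Ptr{Int64}}` the C API expects. A quick standalone illustration (the shapes are made up):

```julia
list = [(2, 3), (4,), ()]
dims = Vector{Int64}[]
for shape in list
    push!(dims, Int64[shape...])
end
# dims == [[2, 3], [4], Int64[]]
```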