Skip to content

Commit

Permalink
optimization: inplace-op, indexing, 0-dim Tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
KDr2 committed Mar 27, 2020
1 parent 5a13dca commit c021717
Show file tree
Hide file tree
Showing 8 changed files with 214 additions and 285 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "3ba4a88b-5b67-4a96-bb04-131d22fbab27"
license = "MIT"
desc = "Julia interface of PyTorch"
repo = "https://github.com/TuringLang/ThArrays.jl.git"
version = "0.1"
version = "0.1.1-dev"

[deps]
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
Expand Down
15 changes: 15 additions & 0 deletions csrc/torch_capi_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ torch::Tensor* tensor_from_data(
return nullptr;
}

// Create a 0-dim (scalar) tensor holding the Int64 value `i`.
// `grad` is an int flag (0/1) forwarded to at::requires_grad.
// PROTECT appears to wrap the body in a try/catch so C++ exceptions
// do not cross the C ABI; on failure control falls through to the
// nullptr return -- TODO confirm the macro's exact semantics.
torch::Tensor* tensor_int64_0dim(int64_t i, int grad) {
    PROTECT(
        return new torch::Tensor(torch::tensor(i, at::requires_grad(grad)));
    );
    return nullptr;
}

void tensor_destroy(torch::Tensor *tensor) {
// std::cout << "DEBUG: Tensor " << tensor << " is destroyed!\n";
if(tensor) delete tensor;
Expand Down Expand Up @@ -149,6 +156,14 @@ void tensor_method_item(torch::Tensor *t, int8_t tid, void *data) {
}
}

// index_select along `dim` with a single Int64 index, so callers can
// pick one slice without building an index tensor themselves.
// NOTE(review): torch::tensor(idx) allocates the index tensor on the
// default device -- verify behavior when *t lives on another device.
torch::Tensor* tensor_method_index_select_int64(
    torch::Tensor *t, int64_t dim, int64_t idx) {
    PROTECT(
        return new torch::Tensor(t->index_select(dim, torch::tensor(idx)));
    );
    return nullptr;
}

void tensor_method_backward(
torch::Tensor *t, torch::Tensor *g,
bool keep_graph=false, bool create_graph=false) {
Expand Down
3 changes: 3 additions & 0 deletions csrc/torch_capi_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ extern "C" {
void *data, size_t datalen, int8_t tid,
int64_t *size_data, int64_t *strides_data, size_t dim,
int copy_data, int grad);
CAPI_DLLEXPORT torch::Tensor* tensor_int64_0dim(int64_t i, int grad);
CAPI_DLLEXPORT void tensor_destroy(torch::Tensor *tensor);
CAPI_DLLEXPORT const char* tensor_to_string(torch::Tensor *tensor);

Expand All @@ -30,6 +31,8 @@ extern "C" {

// methods on Tensor
CAPI_DLLEXPORT void tensor_method_item(torch::Tensor *t, int8_t tid, void *data);
CAPI_DLLEXPORT torch::Tensor* tensor_method_index_select_int64(
torch::Tensor *t, int64_t dim, int64_t idx);
CAPI_DLLEXPORT void tensor_method_backward(
torch::Tensor *t, torch::Tensor *g, bool keep_graph, bool create_graph);

Expand Down
2 changes: 1 addition & 1 deletion src/common-methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Base.sum(t::Tensor{T}) where T = ThC.sum(t, eltype_id(T))

# Addition between a plain number and a Tensor (both argument orders).
Base.:+(r::TorchNumber, t::Tensor) = ThC.add1(t, r)
Base.:+(t::Tensor, r::TorchNumber) = r + t
# Same-eltype, same-rank Tensor addition uses the hand-written fast path.
# (Diff residue removed: the superseded `ThC.add` definition duplicated
# this method and has been dropped in favor of `ThC.opt_add`.)
Base.:+(a::Tensor{T, N}, b::Tensor{T, N}) where {T, N} = ThC.opt_add(a, b)

Base.:-(r::TorchNumber, t::Tensor) = ThC.ones_like(t) * r - t
Base.:-(t::Tensor, r::TorchNumber) = ThC.sub1(t, r)
Expand Down
22 changes: 15 additions & 7 deletions src/tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,22 @@ function Tensor(array::Array{T, N}; detach=false, requires_grad=false) where {T,
end

# 0-dim Tensor
# Build a 0-dim Int64 Tensor through the dedicated C entry point,
# skipping the generic array-marshalling constructor.
function Tensor(s::Int64; requires_grad=false)
    grad_flag = requires_grad ? 1 : 0
    handle = ccall((:tensor_int64_0dim, :libtorch_capi),
                   Ptr{Cvoid},
                   (Clonglong, Cint),
                   s, grad_flag)
    return Tensor{Int64, 0}(handle, nothing)
end

# Generic 0-dim Tensor constructor for any TorchNumber scalar.
function Tensor(s::T; requires_grad=false) where {T <: TorchNumber}
    # Wrap the scalar in a 1-element array so libtorch gets a data
    # pointer; NULL sizes/strides with dim == 0 yield a 0-dim tensor.
    data = T[s]
    grad = requires_grad ? 1 : 0
    # copy_data == 1: libtorch must copy, because `data` is a temporary
    # Julia array the GC may reclaim after this call returns.
    # (Diff residue removed: the superseded copy_data == 0 argument line
    # was an aliasing bug and has been dropped.)
    ptr = ccall((:tensor_from_data, :libtorch_capi),
                Ptr{Cvoid},
                (Ptr{Cvoid}, Csize_t, Cchar,
                 Ptr{Clonglong}, Ptr{Clonglong}, Csize_t, Cint, Cint),
                data, sizeof(T), TYPE_MAP[T], C_NULL, C_NULL, 0, 1, grad)
    Tensor{T, 0}(ptr, nothing)
end

Expand Down Expand Up @@ -149,25 +157,25 @@ function _tensor_indices(t::Tensor, I)
collect(indices), shape
end

# Collapse a tensor to 0-dim / expand it to shape (1, 1).
# (Diff residue removed: the superseded `ThC.reshape` definitions
# duplicated these methods and have been dropped.)
_to_dim_0(t::Tensor) = ThC.opt_reshape(t, Int64[])
_to_dim_1_1(t::Tensor) = ThC.opt_reshape(t, [1, 1])

# Cartesian indexing: apply index_select dimension by dimension, then
# reshape to the computed result shape.
function Base.getindex(t::Tensor, I...)
    ts, shape = _tensor_indices(t, I)
    ret = t
    for i in eachindex(ts)
        # libtorch dimensions are 0-based.
        ret = ThC.opt_index_select(ret, i - 1, ts[i])
    end
    # A fully-scalar result (all singleton dims, empty target shape)
    # collapses to a 0-dim tensor.
    all(x -> x == 1, size(ret)) && shape == Union{}[] && return _to_dim_0(ret)
    # (Diff residue removed: the superseded `reshape(ret, shape)` line.)
    ThC.opt_reshape(ret, shape)
end
# Empty index: extract the scalar value.
Base.getindex(t::Tensor{T}) where T = item(t)
# Linear indexing delegates to cartesian indexing.
Base.getindex(t::Tensor, i::Int64) = t[eachindex(t)[i]]
# 1-d tensor: select one element (0-based in libtorch), drop to 0-dim.
# (Diff residue removed: the superseded `ThC.index_select(t, 0,
# Tensor(i - 1))` body has been dropped in favor of the Int64 fast path.)
Base.getindex(t::Tensor{T, 1}, i::Int64) where T =
    ThC.opt_index_select(t, 0, (i - 1)) |> _to_dim_0
# Range indexing: gather each selected element as a (1, 1) tensor,
# stack them with vcat, then flatten to a 1-d tensor.
function Base.getindex(t::Tensor, I::UnitRange{Int64})
    t = vcat(map(i -> _to_dim_1_1(t[i]), eachindex(t)[I])...)
    # (Diff residue removed: a superseded `reshape(t, [length(t)])`
    # line whose result was discarded.)
    ThC.opt_reshape(t, [length(t)])
end

function Base.setindex!(t::Tensor{T}, v::Tensor{T}, I...) where T
Expand Down
8 changes: 6 additions & 2 deletions src/thc/thc-generator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,9 @@ function ccall_julia_args(f::APIFunction)
end

function return_statement(f::APIFunction)
if f.return_type == "void" && f.args[1].first == "out__"
if match(r"_\d*$", f.func_name) != nothing
return " return self"
elseif f.return_type == "void" && f.args[1].first == "out__"
lines = []
for i in 1:f.output_count
push!(lines,
Expand Down Expand Up @@ -258,7 +260,9 @@ function main()
count += 1
end

write(output, "\n\n")
write(output, "\n")
write(output, "include(\"thc-opt.jl\")\n")
write(output, "\n")
write(output, "end\n") # module end

close(output)
Expand Down
35 changes: 35 additions & 0 deletions src/thc/thc-opt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Fast path for same-eltype, same-rank Tensor + Tensor.
# Calls the generated C wrapper `atg_add`, which writes the result
# tensor's pointer into the first slot of `outputs__`.
function opt_add(self::Tensor{T, N}, other::Tensor{T, N}) where {T, N}
    outputs__ = Int[0]  # out-parameter slot for one tensor pointer
    __cret = ccall((:atg_add, :libtorch_capi),
                   Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}),
                   outputs__, self.pointer, other.pointer)
    # Addition preserves eltype and rank, so the result is Tensor{T, N}.
    return Tensor{T, N}(Ptr{Cvoid}(outputs__[1]), nothing)
end

# index_select with a single Int64 index via the dedicated C entry
# point, avoiding construction of an index Tensor on the Julia side.
# index_select preserves rank, so the result is Tensor{T, N}.
function opt_index_select(self::Tensor{T, N}, dim::Int64, index::Int64) where {T, N}
    ptr = ccall((:tensor_method_index_select_int64, :libtorch_capi),
                Ptr{Cvoid}, (Ptr{Cvoid}, Clonglong, Clonglong),
                self.pointer, dim, index)
    return Tensor{T, N}(ptr, nothing)
end

# Index with a Julia integer array: wrap it in a Tensor and delegate
# to the Tensor-based index_select method.
opt_index_select(self::Tensor{T, N}, dim::Int64, i::Array{Int64}) where {T, N} =
    opt_index_select(self, dim, Tensor(i))

# index_select with a Tensor of indices; calls the generated C wrapper
# `atg_index_select`, which writes the result tensor's pointer into
# the first slot of `outputs__`.
function opt_index_select(self::Tensor{T, N}, dim::Int64, index::Tensor) where {T, N}
    outputs__ = Int[0]  # out-parameter slot for one tensor pointer
    __cret = ccall((:atg_index_select, :libtorch_capi),
                   Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Clonglong, Ptr{Cvoid}),
                   outputs__, self.pointer, dim, index.pointer)
    return Tensor{T, N}(Ptr{Cvoid}(outputs__[1]), nothing)
end

# Reshape `self` to the dimensions in `shape_data` via `atg_reshape`.
# NOTE(review): the result's rank type parameter `shape_len` is a
# runtime value, so this method is type-unstable by construction (the
# output rank depends on the shape argument's length, not its type).
function opt_reshape(self::Tensor{T}, shape_data::Array{Int64}) where T
    outputs__ = Int[0]  # out-parameter slot for one tensor pointer
    shape_len = length(shape_data)
    __cret = ccall((:atg_reshape, :libtorch_capi),
                   Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Cint),
                   outputs__, self.pointer, shape_data, shape_len)
    return Tensor{T, shape_len}(Ptr{Cvoid}(outputs__[1]), nothing)
end
Loading

0 comments on commit c021717

Please sign in to comment.