Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1440] julia: porting current_context #17142

Merged
merged 9 commits into from
Dec 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion julia/src/MXNet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,11 @@ export Context,
cpu,
gpu,
num_gpus,
gpu_memory_info
gpu_memory_info,
current_context,
@context,
@cpu,
@gpu

# model.jl
export AbstractModel,
Expand Down
104 changes: 103 additions & 1 deletion julia/src/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,89 @@ struct Context
Context(dev_type::CONTEXT_TYPE, dev_id::Integer = 0) = new(dev_type, dev_id)
end

const _default_ctx = Ref{Context}(Context(CPU, 0))

Context(dev_type::Integer, dev_id::Integer = 0) =
Context(convert(CONTEXT_TYPE, dev_type), dev_id)

Base.show(io::IO, ctx::Context) =
print(io, "$(ctx.device_type)$(ctx.device_id)")
print(io, lowercase("$(ctx.device_type)$(ctx.device_id)"))

function _with_context(dev_type::Union{Symbol,Expr}, dev_id, e::Expr)
global _default_ctx
quote
ctx = current_context()
ctx′ = Context($(esc(dev_type)), $(esc(dev_id)))
$_default_ctx[] = ctx′
try
return $(esc(e))
finally
$_default_ctx[] = ctx
end
end
end

"""
@context device_type [device_id] expr

Change the default context in the following expression.

# Examples
```jl-repl
julia> mx.@context mx.GPU begin
mx.zeros(2, 3)
end
2×3 NDArray{Float32,2} @ gpu0:
0.0f0 0.0f0 0.0f0
0.0f0 0.0f0 0.0f0

julia> @context mx.GPU mx.zeros(3, 2)
3×2 NDArray{Float32,2} @ gpu0:
0.0f0 0.0f0
0.0f0 0.0f0
0.0f0 0.0f0
```
"""
macro context(dev_type, e::Expr)
_with_context(dev_type, 0, e)
end

macro context(dev_type, dev_id, e::Expr)
_with_context(dev_type, dev_id, e)
end

for dev ∈ [:cpu, :gpu]
ctx = QuoteNode(Symbol(uppercase(string(dev))))
docstring = """
@$dev [device_id] expr

A shorthand for `@context mx.GPU`.

# Examples
```jl-repl
julia> mx.@with_gpu mx.zeros(2, 3)
2×3 NDArray{Float32,2} @ gpu0:
0.0f0 0.0f0 0.0f0
0.0f0 0.0f0 0.0f0
```
"""
@eval begin
@doc $docstring ->
macro $dev(e::Expr)
ctx = $ctx
quote
@context $ctx $(esc(e))
end
end

macro $dev(dev_id, e::Expr)
ctx = $ctx
quote
@context $ctx $(esc(dev_id)) $(esc(e))
end
end
end
end # for dev ∈ [:cpu, :gpu]

"""
cpu(dev_id)
Expand Down Expand Up @@ -86,3 +164,27 @@ function gpu_memory_info(dev_id = 0)
@mxcall :MXGetGPUMemoryInformation64 (Cint, Ref{UInt64}, Ref{UInt64}) dev_id free n
free[], n[]
end

"""
current_context()

Return the current context.

By default, `mx.cpu()` is used for all the computations
and it can be overridden by using the `@context` macro.

# Examples
```jl-repl
julia> mx.current_context()
cpu0

julia> mx.@context mx.GPU 1 begin # Context changed in the following code block
mx.current_context()
end
gpu1

julia> mx.current_context()
cpu0
```
"""
current_context() = _default_ctx[]
18 changes: 10 additions & 8 deletions julia/src/ndarray/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,43 +28,45 @@ Base.similar(x::NDArray{T,N}; writable = x.writable, ctx = context(x)) where {T,
NDArray{T,N}(undef, size(x)...; writable = writable, ctx = ctx)

"""
zeros([DType], dims, [ctx::Context = cpu()])
zeros([DType], dims, ctx::Context = current_context())
zeros([DType], dims...)
zeros(x::NDArray)

Create zero-ed `NDArray` with specific shape and type.
"""
function zeros(::Type{T}, dims::NTuple{N,Int}, ctx::Context = cpu()) where {N,T<:DType}
function zeros(::Type{T}, dims::NTuple{N,Int},
ctx::Context = current_context()) where {N,T<:DType}
x = NDArray{T}(undef, dims..., ctx = ctx)
x[:] = zero(T)
x
end

zeros(::Type{T}, dims::Int...) where {T<:DType} = zeros(T, dims)

zeros(dims::NTuple{N,Int}, ctx::Context = cpu()) where N =
zeros(dims::NTuple{N,Int}, ctx::Context = current_context()) where N =
zeros(MX_float, dims, ctx)
zeros(dims::Int...) = zeros(dims)

zeros(x::NDArray)::typeof(x) = zeros_like(x)
Base.zeros(x::NDArray)::typeof(x) = zeros_like(x)

"""
ones([DType], dims, [ctx::Context = cpu()])
ones([DType], dims, ctx::Context = current_context())
ones([DType], dims...)
ones(x::NDArray)

Create an `NDArray` with specific shape & type, and initialize with 1.
"""
function ones(::Type{T}, dims::NTuple{N,Int}, ctx::Context = cpu()) where {N,T<:DType}
function ones(::Type{T}, dims::NTuple{N,Int},
ctx::Context = current_context()) where {N,T<:DType}
arr = NDArray{T}(undef, dims..., ctx = ctx)
arr[:] = one(T)
arr
end

ones(::Type{T}, dims::Int...) where T<:DType = ones(T, dims)

ones(dims::NTuple{N,Int}, ctx::Context = cpu()) where N =
ones(dims::NTuple{N,Int}, ctx::Context = current_context()) where N =
ones(MX_float, dims, ctx)
ones(dims::Int...) = ones(dims)

Expand Down Expand Up @@ -458,12 +460,12 @@ function Base.fill!(arr::NDArray, x)
end

"""
fill(x, dims, ctx=cpu())
fill(x, dims, ctx = current_context())
fill(x, dims...)

Create an `NDArray` filled with the value `x`, like `Base.fill`.
"""
function fill(x::T, dims::NTuple{N,Integer}, ctx::Context = cpu()) where {T,N}
function fill(x::T, dims::NTuple{N,Integer}, ctx::Context = current_context()) where {T,N}
arr = NDArray{T}(undef, dims, ctx = ctx)
arr[:] = x
arr
Expand Down
2 changes: 1 addition & 1 deletion julia/src/ndarray/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ end

# UndefInitializer constructors
NDArray{T,N}(::UndefInitializer, dims::NTuple{N,Integer};
writable = true, ctx::Context = cpu()) where {T,N} =
writable = true, ctx::Context = current_context()) where {T,N} =
NDArray{T,N}(_ndarray_alloc(T, dims, ctx, false), writable)
NDArray{T,N}(::UndefInitializer, dims::Vararg{Integer,N}; kw...) where {T,N} =
NDArray{T,N}(undef, dims; kw...)
Expand Down
77 changes: 77 additions & 0 deletions julia/test/unittest/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,85 @@ function test_num_gpus()
@test num_gpus() >= 0
end

function test_context_macro()
@info "Context::@context"

@context mx.CPU 42 begin
ctx = mx.current_context()
@test ctx.device_type == mx.CPU
@test ctx.device_id == 42

@context mx.GPU 24 begin
ctx = mx.current_context()
@test ctx.device_type == mx.GPU
@test ctx.device_id == 24
end

ctx = mx.current_context()
@test ctx.device_type == mx.CPU
@test ctx.device_id == 42
end

function f()
ctx = mx.current_context()
@test ctx.device_type == mx.GPU
@test ctx.device_id == 123
end

@context mx.GPU 123 begin
f()
end

@context mx.GPU begin
ctx = mx.current_context()
@test ctx.device_type == mx.GPU
@test ctx.device_id == 0
end

@context mx.CPU begin
ctx = mx.current_context()
@test ctx.device_type == mx.CPU
@test ctx.device_id == 0
end

@info "Context::@gpu"
@gpu 123 f()
@gpu begin
ctx = mx.current_context()
@test ctx.device_type == mx.GPU
@test ctx.device_id == 0
end
let n = 321
@gpu n begin
ctx = mx.current_context()
@test ctx.device_type == mx.GPU
@test ctx.device_id == 321
end
end

@info "Context::@cpu"
@cpu 123 begin
ctx = mx.current_context()
@test ctx.device_type == mx.CPU
@test ctx.device_id == 123
end
@cpu begin
ctx = mx.current_context()
@test ctx.device_type == mx.CPU
@test ctx.device_id == 0
end
let n = 321
@cpu n begin
ctx = mx.current_context()
@test ctx.device_type == mx.CPU
@test ctx.device_id == 321
end
end
end

@testset "Context Test" begin
test_num_gpus()
test_context_macro()
end


Expand Down
2 changes: 1 addition & 1 deletion julia/test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1294,7 +1294,7 @@ function test_show()
@test occursin("1×4", str)
@test occursin("NDArray", str)
@test occursin("Int64", str)
@test occursin("CPU", str)
@test occursin("cpu", str)
@test match(r"1\s+2\s+3\s+4", str) != nothing
end

Expand Down