This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Training Halts when Using CuArrays #691

Closed
lpjiang97 opened this issue Apr 22, 2020 · 6 comments

lpjiang97 commented Apr 22, 2020

I'm fairly new to CuArrays and Flux, and I'm running into a problem where training halts after some number of epochs. There is no CUDA out-of-memory error, but memory usage is extremely high for this simple linear model (99.97% on a 1080 Ti). The code sometimes finishes all 500 epochs without problems, but other times halts around epoch 150.

using LinearAlgebra
using Flux
using CuArrays, CUDAnative
using Flux.Optimise: update!
using Flux: crossentropy

device!(1)
CuArrays.allowscalar(false)
pred_loss(x, y) = sum((x .- y) .^ 2)

# dimensions
B = 250
linear = Dense(400, 144) |> gpu
# normalize the columns of W
linear.W .= linear.W ./ sqrt.(sum(linear.W .^ 2, dims=1));
# training
E = 500
opt_U = Descent(0.01)
for e = 1:E
    running_l = 0
    c = 0
    for b = 1:100
        y = rand(144, B) |> gpu
    R = zeros(400, size(y)[2]) |> gpu
        l = 0
        grads = gradient(params(linear.W)) do
            l = pred_loss(y, linear(R))
            running_l += l
            return l
        end
        update!(opt_U, linear.W, grads[linear.W])
        linear.W .= linear.W ./ sqrt.(sum(linear.W .^ 2, dims=1))
        c += 1
    end
    println("Epoch: $e, Running loss: $(running_l / c)")
end

I'm seeing this problem on Ubuntu 18.04 with CuArrays v2.1.0. I would appreciate some pointers on this.

lpjiang97 added the bug label on Apr 22, 2020

maleadt commented Apr 23, 2020

How does training halt? You need to be getting some kind of error or reason it halts, right?

Can you profile the code and see where the time goes? A typical problem is memory: GPUs have much less RAM than the host, which doesn't compose well with Julia's GC when running close to the memory limit. Yours has 11 GB though, so unless your model is huge that should generally work fine.
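
For example, timing each epoch with plain @time (a rough sketch based on your snippet, not exact code) would already show whether the time is going to GC:

for e = 1:E
    # Base's @time reports "% gc time"; if that dominates, the GC/pool interaction is the culprit.
    @time for b = 1:100
        # ... training step exactly as in the snippet above ...
    end
end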

Also, which version of Julia?


lpjiang97 commented Apr 23, 2020

How does training halt? You need to be getting some kind of error or reason it halts, right?

The println statement (second-to-last line) stops printing. I have tried letting it run for about an hour, with no further updates.

Can you profile the code and see where the time goes? A typical problem is memory: GPUs have much less RAM than the host, which doesn't compose well with Julia's GC when running close to the memory limit. Yours has 11 GB though, so unless your model is huge that should generally work fine.

When I look at CuArrays.memory_status(), the usage is very high (99.97%), which I find very strange. I have a similar model written in PyTorch that takes only about 500 MB.
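
Roughly, I check it at the end of each epoch like this:

for e = 1:E
    # ... inner training loop as in the original snippet ...
    CuArrays.memory_status()   # prints pool usage; here it reports ~99.97% used
end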

Also, which version of Julia?

I have tested this on both Julia 1.3.1 and Julia 1.4.1, and both have this problem. I also wonder if this is related to #350, but I have tried taking out the sqrt. and still see the same issue.

Update: if I call GC.gc() at the end of each epoch, the problem goes away.
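
In other words, roughly:

for e = 1:E
    # ... inner training loop as above ...
    println("Epoch: $e, Running loss: $(running_l / c)")
    GC.gc()   # forcing a full collection at the end of each epoch avoids the hang
end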


maleadt commented Apr 23, 2020

Update: if I call GC.gc() at the end of each epoch, the problem goes away.

Ah, so that problem again. I thought training exited, but it hangs, which is consistent with the GC taking up all the time. This is a tough problem, but it's good to have another (smallish) reproducer.

You can also try the new, work-in-progress memory pool: JULIA_CUDA_MEMORY_POOL=split. It improves performance in some workloads, but it ultimately still falls back on the Julia GC, so it might still hit the same problem with your use case.
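
E.g. (I believe the variable needs to be set before CuArrays is loaded for it to take effect):

# select the splitting pool before loading CuArrays, either in the shell or like this:
ENV["JULIA_CUDA_MEMORY_POOL"] = "split"
using CuArrays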


colinxs commented Apr 23, 2020

Hey @maleadt, I work with @lpjiang97 and spent a bit of time looking into this. For what it's worth, I was able to replicate this on my machine (information below) in 3/5 attempts, each in a fresh Julia session. When it does lock up, the stacktrace shows that it's waiting on the lock in either alloc or free:

Stacktrace:
 [1] top-level scope at /home/colinxs/workspace/dev/Experiments/flux/foo/debug0.jl:29
 [2] lock(::Base.Threads.SpinLock) at ./locks-mt.jl:71
 [3] macro expansion at ./lock.jl:181 [inlined]
 [4] free(::CUDAdrv.CuPtr{Nothing}) at /home/colinxs/.julia/packages/CuArrays/4Q1BY/src/memory/binned.jl:393
 [5] macro expansion at /home/colinxs/.julia/packages/TimerOutputs/NvIUx/src/TimerOutput.jl:245 [inlined]
 [6] macro expansion at /home/colinxs/.julia/packages/CuArrays/4Q1BY/src/memory.jl:231 [inlined]
 [7] macro expansion at ./util.jl:234 [inlined]
 [8] free at /home/colinxs/.julia/packages/CuArrays/4Q1BY/src/memory.jl:230 [inlined]
 [9] _unsafe_free!(::CuArray{Float32,2,Nothing}) at /home/colinxs/.julia/packages/CuArrays/4Q1BY/src/array.jl:51
 [10] unsafe_free!(::CuArray{Float32,2,Nothing}) at /home/colinxs/.julia/packages/CuArrays/4Q1BY/src/array.jl:40

Single GPU (1050)

julia> versioninfo()
Julia Version 1.4.1
Commit 381693d3df* (2020-04-14 17:20 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i7-8850H CPU @ 2.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-8.0.1 (ORCJIT, skylake)
Environment:
  JULIA_DOWNLOAD = /home/colinxs/pkg/installed/julia
  JULIA_NUM_THREADS = 6
  JULIA_PKG_DEVDIR = /home/colinxs/workspace/juliadev

julia> Pkg.status()
Status `~/workspace/dev/Experiments/flux/foo/Project.toml`
  [3895d2a7] CUDAapi v4.0.0
  [be33ccc6] CUDAnative v3.0.4
  [3a865a2d] CuArrays v2.1.0
  [587475ba] Flux v0.10.4


colinxs commented Apr 23, 2020

I should've checked the open issues first; it appears you're already well aware of this: #685


maleadt commented Apr 24, 2020

Correct. I suspected a performance issue, but the backtrace is useful for identifying the actual problem. I'll have a look at the deadlock, since a couple of users have been running into this.
