Skip to content

Commit

Permalink
No commit message
Browse files Browse the repository at this point in the history
  • Loading branch information
paulxshen committed Dec 31, 2024
1 parent ff6a466 commit 00d63cb
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
8 changes: 5 additions & 3 deletions src/pic/run.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ function picrun(path; gpuarray=nothing, kw...)
end
end
model = models[1]
opt = AreaChangeOptimiser(model)
opt = AreaChangeOptimiser(model; minchange=0.001)
opt_state = Flux.setup(opt, model)
println("starting optimization... first iter will be slow due to adjoint compilation.")
img = nothing
Expand Down Expand Up @@ -232,7 +232,7 @@ function picrun(path; gpuarray=nothing, kw...)
# T * dϕ / π
end
else
@time global l, (dldm,) = Flux.withgradient(model) do model
function f(model)
models = [model]
res = make_pic_sim_prob(runs, run_probs, lb, dl,
designs, design_config, models, ;
Expand Down Expand Up @@ -282,6 +282,9 @@ function picrun(path; gpuarray=nothing, kw...)
println(" weighted total loss $l")
l
end

@time global l, (dldm,) = Flux.withgradient(f, model)

end
@assert !isnothing(dldm)
if !isnothing(stoploss) && l < stoploss
Expand Down Expand Up @@ -325,7 +328,6 @@ function picrun(path; gpuarray=nothing, kw...)
opt.maxchange = 0.001 + relu.(l - [0.1, 0.3, 0.7]) [0.01, 0.01, 0.01]
Jello.update_loss!(opt, l)
Flux.update!(opt_state, model, dldm)# |> gpu)
repair!(model)
end
if framerate > 0
make_pic_sim_prob(runs, run_probs, lb, dl,
Expand Down
6 changes: 3 additions & 3 deletions src/pictest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ ENV["JULIA_PKG_PRECOMPILE_AUTO"] = 0
# picrun(joinpath("runs", "bend_R5"))
# picrun(joinpath("runs", "mode_converter"))
# picrun(joinpath("runs", "demux"))
# picrun(joinpath("runs", "splitter"))
picrun(joinpath("runs", "splitter"))

using CUDA
picrun(joinpath("runs", "tiny"); gpuarray=cu)
# using CUDA
# picrun(joinpath("runs", "tiny"); gpuarray=cu)

0 comments on commit 00d63cb

Please sign in to comment.