From d612faf9554300e029138a0cc219685d42605c24 Mon Sep 17 00:00:00 2001 From: kasiabozek Date: Wed, 9 Dec 2015 15:59:37 +0900 Subject: [PATCH 1/3] Accuracy metric added to epoch callbacks.. --- src/callback.jl | 8 ++++---- src/model.jl | 13 +++++++------ 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/callback.jl b/src/callback.jl index 9f3d85b576ff..e18184b1131a 100644 --- a/src/callback.jl +++ b/src/callback.jl @@ -110,13 +110,13 @@ end function every_n_epoch(callback :: Function, n :: Int; call_on_0 :: Bool = false) EpochCallback(n, call_on_0, callback) end -function Base.call(cb :: EpochCallback, model :: Any, state :: OptimizationState) +function Base.call{T<:Real}(cb :: EpochCallback, model :: Any, state :: OptimizationState, metric :: Vector{Tuple{Base.Symbol, T}}) if state.curr_epoch == 0 if cb.call_on_0 - cb.callback(model, state) + cb.callback(model, state, metric) end elseif state.curr_epoch % cb.frequency == 0 - cb.callback(model, state) + cb.callback(model, state, metric) end end @@ -136,7 +136,7 @@ end =# function do_checkpoint(prefix::AbstractString; frequency::Int=1, save_epoch_0=false) mkpath(dirname(prefix)) - every_n_epoch(frequency, call_on_0=save_epoch_0) do model, state + every_n_epoch(frequency, call_on_0=save_epoch_0) do model, state, metric save_checkpoint(model, prefix, state) end end diff --git a/src/model.jl b/src/model.jl index 009471c785f2..a8e5c49df28e 100644 --- a/src/model.jl +++ b/src/model.jl @@ -260,7 +260,7 @@ function _create_kvstore(kv_type :: Base.Symbol, num_device :: Int, arg_params : return (kv, update_on_kvstore) end -@defstruct TrainingOptions ( +@defstruct TrainingOptions Any ( initializer :: AbstractInitializer = UniformInitializer(0.01), n_epoch :: Int = 10, eval_data :: Union{Void, AbstractDataProvider} = nothing, @@ -270,13 +270,14 @@ end callbacks :: Vector{AbstractCallback} = AbstractCallback[], ) -function _invoke_callbacks(self::FeedForward, callbacks::Vector{AbstractCallback}, - state::OptimizationState, type_filter::Type) +function _invoke_callbacks{T<:Real}(self::FeedForward, callbacks::Vector{AbstractCallback}, + state::OptimizationState, type_filter::Type; + metric::Vector{Tuple{Base.Symbol, T}} = Vector{Tuple{Base.Symbol, Real}}()) map(callbacks) do cb if isa(cb, type_filter) if type_filter == AbstractEpochCallback # epoch callback have extra access to the model object - cb(self, state) + cb(self, state, metric) else cb(state) end @@ -465,6 +466,7 @@ function fit(self :: FeedForward, optimizer :: AbstractOptimizer, data :: Abstra end # end of one epoch time_stop = time() + metric = get(opts.eval_metric) info(format("== Epoch {1:0>3d} ==========", i_epoch)) info("## Training summary") for (name, value) in get(opts.eval_metric) @@ -514,7 +516,7 @@ function fit(self :: FeedForward, optimizer :: AbstractOptimizer, data :: Abstra copy!(self.aux_params[name], aux_avg) end end - _invoke_callbacks(self, opts.callbacks, op_state, AbstractEpochCallback) + _invoke_callbacks(self, opts.callbacks, op_state, AbstractEpochCallback; metric=metric) end # end of all epochs end @@ -573,4 +575,3 @@ function load_checkpoint(self :: FeedForward, prefix :: AbstractString, epoch :: self.aux_params = aux_params return self end - From a5c66935c85a4f746530f4011b3e73a72de60b1e Mon Sep 17 00:00:00 2001 From: kasiabozek Date: Wed, 9 Dec 2015 16:04:20 +0900 Subject: [PATCH 2/3] Accuracy metric added to epoch callbacks. --- src/model.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model.jl b/src/model.jl index a8e5c49df28e..2745310da03a 100644 --- a/src/model.jl +++ b/src/model.jl @@ -469,7 +469,7 @@ function fit(self :: FeedForward, optimizer :: AbstractOptimizer, data :: Abstra metric = get(opts.eval_metric) info(format("== Epoch {1:0>3d} ==========", i_epoch)) info("## Training summary") - for (name, value) in get(opts.eval_metric) + for (name, value) in metric info(format("{1:>18s} = {2:.4f}", string(name), value)) end info(format("{1:>18s} = {2:.4f} seconds", "time", time_stop-time_start)) From 49a92d1e90e09e468cdb48a99f9898dee52946b9 Mon Sep 17 00:00:00 2001 From: kasiabozek Date: Wed, 9 Dec 2015 17:06:23 +0900 Subject: [PATCH 3/3] Accuracy metric added to epoch callbacks. --- src/model.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model.jl b/src/model.jl index 2745310da03a..d4b492333f9f 100644 --- a/src/model.jl +++ b/src/model.jl @@ -260,7 +260,7 @@ function _create_kvstore(kv_type :: Base.Symbol, num_device :: Int, arg_params : return (kv, update_on_kvstore) end -@defstruct TrainingOptions Any ( +@defstruct TrainingOptions ( initializer :: AbstractInitializer = UniformInitializer(0.01), n_epoch :: Int = 10, eval_data :: Union{Void, AbstractDataProvider} = nothing,