diff --git a/Project.toml b/Project.toml index 3f0ab3c..4a8017b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "EarlyStopping" uuid = "792122b4-ca99-40de-a6bc-6742525f08b6" authors = ["Anthony D. Blaom "] -version = "0.1.1" +version = "0.1.2" [deps] Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" @@ -11,7 +11,8 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" julia = "^1" [extras] +InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test"] +test = ["InteractiveUtils", "Test"] diff --git a/src/EarlyStopping.jl b/src/EarlyStopping.jl index 8cf7676..6172b19 100644 --- a/src/EarlyStopping.jl +++ b/src/EarlyStopping.jl @@ -6,7 +6,8 @@ import Base.+ export StoppingCriterion, Never, NotANumber, TimeLimit, GL, Patience, UP, PQ, NumberLimit, - Disjunction, criteria, stopping_time, EarlyStopper, + Threshold, Disjunction, + criteria, stopping_time, EarlyStopper, done!, message, needs_in_and_out_of_sample include("api.jl") diff --git a/src/criteria.jl b/src/criteria.jl index 11ec0e3..dde2f41 100644 --- a/src/criteria.jl +++ b/src/criteria.jl @@ -124,6 +124,8 @@ GL(; alpha=2.0) = GL(alpha) update(::GL, loss) = (loss=loss, min_loss=loss) update(::GL, loss, state) = (loss=loss, min_loss=min(loss, state.min_loss)) +# in case first loss consumed was a training loss: +update(criterion::GL, loss, ::Nothing) = update(criterion, loss) done(criterion::GL, state) = generalization_loss(state.loss, state.min_loss) > criterion.alpha @@ -295,6 +297,8 @@ update(criterion::Patience, loss) = (loss=loss, n_increases=0) end return (loss=loss, n_increases=n) end +# in case first loss consumed was a training loss: +update(criterion::Patience, loss, ::Nothing) = update(criterion, loss) done(criterion::Patience, state) = state.n_increases == criterion.n @@ -324,5 +328,30 @@ update(criterion::NumberLimit, loss) = 1 @inline function update(criterion::NumberLimit, loss, state) return state+1 end +# in case first loss consumed was a training loss: +update(criterion::NumberLimit, loss, ::Nothing) = update(criterion, loss) done(criterion::NumberLimit, state) = state == criterion.n + + +# ## THRESHOLD + +""" + Threshold(; value=0.0) + +$STOPPING_DOC + +A stop is triggered as soon as the loss drops below `value`. + +""" +mutable struct Threshold <: StoppingCriterion + value::Float64 +end +Threshold(; value=0.0) = Threshold(value) + +update(criterion::Threshold, loss) = loss +update(criterion::Threshold, loss, state) = loss +# in case first loss consumed was a training loss: +update(criterion::Threshold, loss, ::Nothing) = loss + +done(criterion::Threshold, state) = state < criterion.value diff --git a/src/disjunction.jl b/src/disjunction.jl index 9ff2798..43a5289 100644 --- a/src/disjunction.jl +++ b/src/disjunction.jl @@ -20,6 +20,7 @@ struct Disjunction{A,B} <: StoppingCriterion end end +Disjunction() = Never() Disjunction(a) = a Disjunction(a, b, c...) = Disjunction(Disjunction(a,b), c...) diff --git a/test/criteria.jl b/test/criteria.jl index e7ab009..9048536 100644 --- a/test/criteria.jl +++ b/test/criteria.jl @@ -159,4 +159,17 @@ end end end +@testset "Threshold" begin + @test Threshold().value == 0.0 + stopping_time(Threshold(2.5), Float64[12, 32, 3, 2, 5, 7]) == 4 +end + +@testset "robustness to first loss being a training loss" begin + for C in subtypes(StoppingCriterion) + losses = float.(4:-1:1) + is_training = [true, true, false, false] + stopping_time(C(), losses, is_training) + end +end + true diff --git a/test/runtests.jl b/test/runtests.jl index e8e9220..99416fe 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,4 @@ -using EarlyStopping, Dates, Test +using EarlyStopping, Dates, Test, InteractiveUtils @testset "criteria.jl" begin include("criteria.jl")