Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

For a 0.1.6 release #19

Merged
merged 5 commits into from
Mar 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EarlyStopping"
uuid = "792122b4-ca99-40de-a6bc-6742525f08b6"
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
version = "0.1.5"
version = "0.1.6"

[deps]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ criterion | description |
`NotANumber()` | Stop when `NaN` encountered |
`TimeLimit(t=0.5)` | Stop after `t` hours |
`NumberLimit(n=100)` | Stop after `n` loss updates (excl. "training losses") |
`NumberSinceBest(n=6)`| Stop after `n` loss updates (excl. "training losses") |
`Threshold(value=0.0)`| Stop when `loss < value` |
`GL(alpha=2.0)` | Stop after "Generalization Loss" exceeds `alpha` | ``GL_α``
`PQ(alpha=0.75, k=5)` | Stop after "Progress-modified GL" exceeds `alpha` | ``PQ_α``
Expand Down
4 changes: 3 additions & 1 deletion src/EarlyStopping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ using Statistics
import Base.+

export StoppingCriterion,
Never, NotANumber, TimeLimit, GL, Patience, UP, PQ, NumberLimit,
Never, NotANumber, TimeLimit, GL, NumberSinceBest,
Patience,
UP, PQ, NumberLimit,
Threshold, Disjunction,
criteria, stopping_time, EarlyStopper,
done!, message, needs_training_losses
Expand Down
41 changes: 41 additions & 0 deletions src/criteria.jl
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,47 @@ done(criterion::Patience, state) = state.n_increases == criterion.n
needs_loss(::Type{<:Patience}) = true


## NUMBER SINCE BEST

"""
NumberSinceBest(; n=6)

$STOPPING_DOC

A stop is triggered when the number of calls to the control, since the
lowest value of the loss so far, is `n`.

"""
struct NumberSinceBest <: StoppingCriterion
n::Int
function NumberSinceBest(n::Int)
n > 0 ||
throw(ArgumentError("`n` must be positive. "))
return new(n)
end
end
NumberSinceBest(; n=6) = NumberSinceBest(n)

update(criterion::NumberSinceBest, loss) = (best=loss, number_since_best=0)
@inline function update(criterion::NumberSinceBest, loss, state)
best, number_since_best = state
if loss < best
best = loss
number_since_best = 0
else
number_since_best += 1
end
return (best=best, number_since_best=number_since_best)
end

# in case first loss consumed was a training loss:
update(criterion::NumberSinceBest, loss, ::Nothing) = update(criterion, loss)

done(criterion::NumberSinceBest, state) = state.number_since_best == criterion.n

needs_loss(::Type{<:NumberSinceBest}) = true


# # NUMBER LIMIT

"""
Expand Down
17 changes: 17 additions & 0 deletions test/criteria.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,23 @@ end
@test !EarlyStopping.needs_training_losses(Patience())
end

@testset "NumberSinceBest" begin
@test_throws ArgumentError NumberSinceBest(n=0)
@test stopping_time(NumberSinceBest(n=6), losses) == 8
@test stopping_time(NumberSinceBest(n=5), losses) == 7
@test stopping_time(NumberSinceBest(n=4), losses) == 6
@test stopping_time(NumberSinceBest(n=3), losses) == 5
@test stopping_time(NumberSinceBest(n=2), losses) == 4
@test stopping_time(NumberSinceBest(n=1), losses) == 3

losses2 = Float64[10, 9, 8, 9, 10, 7, 10, 10, 10, 10]
@test stopping_time(NumberSinceBest(n=2), losses2) == 5
@test stopping_time(NumberSinceBest(n=3), losses2) == 9

@test EarlyStopping.needs_loss(NumberSinceBest())
@test !EarlyStopping.needs_training_losses(NumberSinceBest())
end

@testset "NumberLimit" begin
@test_throws ArgumentError NumberLimit(n=0)

Expand Down