Skip to content

Commit da9f1e4

Browse files
committed
Fix a few bugs in Summary and write more tests.
1 parent c5781ef commit da9f1e4

2 files changed

Lines changed: 108 additions & 93 deletions

File tree

src/stopping_criterion.jl

Lines changed: 43 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,10 @@ is_finished!(::Problem, ::Algorithm, ::State, ::StoppingCriterion, ::StoppingCri
105105

106106
@doc """
107107
summary(io::IO, stopping_criterion::StoppingCriterion, stopping_criterion_state::StoppingCriterionState)
108+
summary(stopping_criterion::StoppingCriterion, stopping_criterion_state::StoppingCriterionState)
108109
109110
Provide a summary of the status of a stopping criterion – its parameters and whether
110-
it currently indicates to stop.
111+
it currently indicates to stop.ìo` and generates a string from that.
111112
112113
# Example
113114
@@ -119,6 +120,15 @@ Max Iterations (15): not reached
119120
"""
120121
Base.summary(io::IO, ::StoppingCriterion, ::StoppingCriterionState)
121122

123+
function Base.summary(
124+
stopping_criterion::StoppingCriterion,
125+
stopping_criterion_state::StoppingCriterionState
126+
)
127+
io = IOBuffer()
128+
summary(io, stopping_criterion, stopping_criterion_state)
129+
return String(take!(io))
130+
end
131+
122132
#
123133
#
124134
# Meta StoppingCriteria
@@ -237,8 +247,6 @@ This is for example used in combination with [`StopWhenAny`](@ref) and [`StopWhe
237247
mutable struct GroupStoppingCriterionState{TCriteriaStates <: Tuple} <: StoppingCriterionState
238248
criteria_states::TCriteriaStates
239249
at_iteration::Int
240-
GroupStoppingCriterionState(c::Vector{<:StoppingCriterionState}) =
241-
new{typeof(tuple(c...))}(tuple(c...), -1)
242250
GroupStoppingCriterionState(c::StoppingCriterionState...) = new{typeof(c)}(c, -1)
243251
end
244252

@@ -247,15 +255,13 @@ function get_reason(
247255
stopping_criterion_states::GroupStoppingCriterionState,
248256
)
249257
stopping_criterion_states.at_iteration < 0 && return nothing
250-
criteria = stop_when.criteriaq
258+
criteria = stop_when.criteria
251259
stopping_criterion_states = stopping_criterion_states.criteria_states
252260
return join(Iterators.map(get_reason, criteria, stopping_criterion_states))
253261
end
254262

255263
function initialize_state(
256-
problem::Problem,
257-
algorithm::Algorithm,
258-
stop_when::Union{StopWhenAll, StopWhenAny};
264+
problem::Problem, algorithm::Algorithm, stop_when::Union{StopWhenAll, StopWhenAny};
259265
kwargs...,
260266
)
261267
return GroupStoppingCriterionState(
@@ -266,19 +272,14 @@ function initialize_state(
266272
)
267273
end
268274
function initialize_state!(
269-
problem::Problem,
270-
algorithm::Algorithm,
271-
stop_when::Union{StopWhenAll, StopWhenAny},
275+
problem::Problem, algorithm::Algorithm, stop_when::Union{StopWhenAll, StopWhenAny},
272276
stopping_criterion_states::GroupStoppingCriterionState;
273277
kwargs...,
274278
)
275279
for (stopping_criterion_state, stopping_criterion) in
276280
zip(stopping_criterion_states.criteria_states, stop_when.criteria)
277281
initialize_state!(
278-
problem,
279-
algorithm,
280-
stopping_criterion,
281-
stopping_criterion_state;
282+
problem, algorithm, stopping_criterion, stopping_criterion_state;
282283
kwargs...,
283284
)
284285
end
@@ -287,11 +288,8 @@ function initialize_state!(
287288
end
288289

289290
function is_finished(
290-
problem::Problem,
291-
algorithm::Algorithm,
292-
state::State,
293-
stop_when_all::StopWhenAll,
294-
stopping_criterion_states::GroupStoppingCriterionState,
291+
problem::Problem, algorithm::Algorithm, state::State,
292+
stop_when_all::StopWhenAll, stopping_criterion_states::GroupStoppingCriterionState,
295293
)
296294
k = state.iteration
297295
(k == 0) && (stopping_criterion_states.at_iteration = -1) # reset on init
@@ -304,11 +302,8 @@ function is_finished(
304302
return false
305303
end
306304
function is_finished!(
307-
problem::Problem,
308-
algorithm::Algorithm,
309-
state::State,
310-
stop_when_all::StopWhenAll,
311-
stopping_criterion_states::GroupStoppingCriterionState,
305+
problem::Problem, algorithm::Algorithm, state::State,
306+
stop_when_all::StopWhenAll, stopping_criterion_states::GroupStoppingCriterionState,
312307
)
313308
k = state.iteration
314309
(k == 0) && (stopping_criterion_states.at_iteration = -1) # reset on init
@@ -323,11 +318,8 @@ function is_finished!(
323318
end
324319

325320
function is_finished(
326-
problem::Problem,
327-
algorithm::Algorithm,
328-
state::State,
329-
stop_when_any::StopWhenAny,
330-
stopping_criterion_states::GroupStoppingCriterionState,
321+
problem::Problem, algorithm::Algorithm, state::State,
322+
stop_when_any::StopWhenAny, stopping_criterion_states::GroupStoppingCriterionState,
331323
)
332324
k = state.iteration
333325
(k == 0) && (stopping_criterion_states.at_iteration = -1) # reset on init
@@ -340,11 +332,8 @@ function is_finished(
340332
return false
341333
end
342334
function is_finished!(
343-
problem::Problem,
344-
algorithm::Algorithm,
345-
state::State,
346-
stop_when_any::StopWhenAny,
347-
stopping_criterion_states::GroupStoppingCriterionState,
335+
problem::Problem, algorithm::Algorithm, state::State,
336+
stop_when_any::StopWhenAny, stopping_criterion_states::GroupStoppingCriterionState,
348337
)
349338
k = state.iteration
350339
(k == 0) && (stopping_criterion_states.at_iteration = -1) # reset on init
@@ -360,33 +349,31 @@ end
360349

361350
function Base.summary(
362351
io::IO,
363-
stop_when_any::StopWhenAny,
364-
stopping_criterion_states::GroupStoppingCriterionState,
352+
stop_when_any::StopWhenAny, stopping_criterion_states::GroupStoppingCriterionState,
365353
)
366354
has_stopped = (stopping_criterion_states.at_iteration >= 0)
367355
s = has_stopped ? "reached" : "not reached"
368-
r = "Stop When _one_ of the following are fulfilled:\n"
356+
r = "Stop when _one_ of the following are fulfilled:\n"
369357
for (stopping_criterion, stopping_criterion_state) in
370358
zip(stop_when_any.criteria, stopping_criterion_states.criteria_states)
371-
s = replace(summary(stopping_criterion, stopping_criterion_state), "\n" => "\n ")
372-
r = "$r $(s)\n"
359+
t = replace(summary(stopping_criterion, stopping_criterion_state), "\n" => "\n ")
360+
r = "$(r) $(t)\n"
373361
end
374-
return print(io, "$(r)Overall: $s")
362+
return print(io, "$(r)Overall: $(s)")
375363
end
376364
function Base.summary(
377365
io::IO,
378-
stop_when_all::StopWhenAll,
379-
stopping_criterion_states::GroupStoppingCriterionState,
366+
stop_when_all::StopWhenAll, stopping_criterion_states::GroupStoppingCriterionState,
380367
)
381368
has_stopped = (stopping_criterion_states.at_iteration >= 0)
382369
s = has_stopped ? "reached" : "not reached"
383-
r = "Stop When _all_ of the following are fulfilled:\n"
370+
r = "Stop when _all_ of the following are fulfilled:\n"
384371
for (stopping_criterion, stopping_criterion_state) in
385372
zip(stop_when_all.criteria, stopping_criterion_states.criteria_states)
386-
s = replace(summary(stopping_criterion, stopping_criterion_state), "\n" => "\n ")
387-
r = "$r $(s)\n"
373+
t = replace(summary(stopping_criterion, stopping_criterion_state), "\n" => "\n ")
374+
r = "$(r) $(t)\n"
388375
end
389-
return print(io, "$(r)Overall: $s")
376+
return print(io, "$(r)Overall: $(s)")
390377
end
391378

392379
#
@@ -429,12 +416,9 @@ mutable struct DefaultStoppingCriterionState <: StoppingCriterionState
429416
DefaultStoppingCriterionState() = new(-1)
430417
end
431418

432-
initialize_state(::Problem, ::Algorithm, ::StopAfterIteration; kwargs...) =
433-
DefaultStoppingCriterionState()
419+
initialize_state(::Problem, ::Algorithm, ::StopAfterIteration; kwargs...) = DefaultStoppingCriterionState()
434420
function initialize_state!(
435-
::Problem,
436-
::Algorithm,
437-
::StopAfterIteration,
421+
::Problem, ::Algorithm, ::StopAfterIteration,
438422
stopping_criterion_state::DefaultStoppingCriterionState;
439423
kwargs...,
440424
)
@@ -444,18 +428,14 @@ end
444428

445429

446430
function is_finished(
447-
::Problem,
448-
::Algorithm,
449-
state::State,
431+
::Problem, ::Algorithm, state::State,
450432
stop_after_iteration::StopAfterIteration,
451433
stopping_criterion_state::DefaultStoppingCriterionState,
452434
)
453435
return state.iteration >= stop_after_iteration.max_iterations
454436
end
455437
function is_finished!(
456-
::Problem,
457-
::Algorithm,
458-
state::State,
438+
::Problem, ::Algorithm, state::State,
459439
stop_after_iteration::StopAfterIteration,
460440
stopping_criterion_state::DefaultStoppingCriterionState,
461441
)
@@ -541,9 +521,7 @@ initialize_state(::Problem, ::Algorithm, ::StopAfter; kwargs...) =
541521
StopAfterTimePeriodState()
542522

543523
function initialize_state!(
544-
::Problem,
545-
::Algorithm,
546-
::StopAfter,
524+
::Problem, ::Algorithm, ::StopAfter,
547525
stopping_criterion_state::StopAfterTimePeriodState;
548526
kwargs...,
549527
)
@@ -554,22 +532,16 @@ function initialize_state!(
554532
end
555533

556534
function is_finished(
557-
::Problem,
558-
::Algorithm,
559-
state::State,
560-
stop_after::StopAfter,
561-
stop_after_state::StopAfterTimePeriodState,
535+
::Problem, ::Algorithm, state::State,
536+
stop_after::StopAfter, stop_after_state::StopAfterTimePeriodState,
562537
)
563538
k = state.iteration
564539
# Just check whether the (last recorded) time is beyond the threshold
565540
return (k > 0 && (stop_after_state.time > Nanosecond(stop_after.threshold)))
566541
end
567542
function is_finished!(
568-
::Problem,
569-
::Algorithm,
570-
state::State,
571-
stop_after::StopAfter,
572-
stop_after_state::StopAfterTimePeriodState,
543+
::Problem, ::Algorithm, state::State,
544+
stop_after::StopAfter, stop_after_state::StopAfterTimePeriodState,
573545
)
574546
k = state.iteration
575547
if value(stop_after_state.start) == 0 || k <= 0 # (re)start timer
@@ -596,8 +568,7 @@ function get_reason(
596568
end
597569
function Base.summary(
598570
io::IO,
599-
stop_after::StopAfter,
600-
stopping_criterion_state::StopAfterTimePeriodState,
571+
stop_after::StopAfter, stopping_criterion_state::StopAfterTimePeriodState,
601572
)
602573
has_stopped = (stopping_criterion_state.at_iteration >= 0)
603574
s = has_stopped ? "reached" : "not reached"

test/stopping_criterion.jl

Lines changed: 65 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ problem = AIT.DummyProblem()
1414
s1_state = initialize_state(problem, algorithm, s1)
1515
@test !indicates_convergence(s1, s1_state)
1616
state_finished = AIT.DummyState(s1_state, 2)
17-
state_not_finished = AIT.DummyState(s1_state, 1)
17+
alg_state = AIT.DummyState(s1_state, 1)
1818
@test is_finished(problem, algorithm, state_finished)
19-
@test !is_finished(problem, algorithm, state_not_finished)
19+
@test !is_finished(problem, algorithm, alg_state)
2020
end
2121

2222
@testset "StopAfter" begin
@@ -26,43 +26,87 @@ end
2626

2727
algorithm = AIT.DummyAlgorithm(s1)
2828
s1_state = initialize_state(problem, algorithm, s1)
29-
state_not_finished = AIT.DummyState(s1_state, 1)
30-
@test !is_finished(problem, algorithm, state_not_finished)
29+
alg_state = AIT.DummyState(s1_state, 1)
30+
@test !is_finished(problem, algorithm, alg_state)
3131
s1_state.time = Second(2)
32-
@test is_finished(problem, algorithm, state_not_finished)
32+
@test is_finished(problem, algorithm, alg_state)
3333
end
3434

3535
@testset "StopWhenAll" begin
36-
s1 = StopAfterIteration(2) & StopAfter(Second(1))
37-
s1b = StopWhenAll([StopAfterIteration(2), StopAfter(Second(1))])
36+
c1 = StopAfterIteration(2)
37+
c2 = StopAfter(Nanosecond(2))
38+
c3 = StopAfterIteration(3)
39+
s1 = c1 & c2
40+
s1b = StopWhenAll([c1, c2])
3841
@test s1 == s1b
3942
@test s1 isa StoppingCriterion
4043
@test sprint((io, x) -> show(io, MIME"text/plain"(), x), s1) ==
41-
"StopWhenAll with the Stopping Criteria:\n StopAfterIteration(2)\n StopAfter(Second(1))"
42-
44+
"StopWhenAll with the Stopping Criteria:\n StopAfterIteration(2)\n StopAfter(Nanosecond(2))"
4345
algorithm = AIT.DummyAlgorithm(s1)
4446
s1_state = initialize_state(problem, algorithm, s1)
45-
state_not_finished = AIT.DummyState(s1_state, 1)
46-
@test !is_finished(problem, algorithm, state_not_finished)
47-
s1_state.criteria_states[2].time = Second(2)
48-
@test !is_finished(problem, algorithm, state_not_finished)
49-
state_not_finished.iteration = 2
50-
@test is_finished(problem, algorithm, state_not_finished)
47+
48+
s1_str = summary(s1, s1_state)
49+
@test contains(s1_str, "Stop when _all_ ")
50+
@test contains(s1_str, "Overall: not reached")
51+
52+
@test isnothing(AlgorithmsInterface.get_reason(s1, s1_state))
53+
alg_state = AIT.DummyState(s1_state, 1)
54+
@test !is_finished(problem, algorithm, alg_state)
55+
# Fake start timer
56+
s1_state.criteria_states[2].start = Nanosecond(time_ns())
57+
s1_state.criteria_states[2].time = Nanosecond(7)
58+
# just time is not enough
59+
@test !is_finished(problem, algorithm, alg_state)
60+
alg_state.iteration = 2
61+
# but now both are
62+
@test is_finished!(problem, algorithm, alg_state)
5163
@test !indicates_convergence(s1)
64+
# check that reset works (a) check with modification
65+
@test is_finished!(problem, algorithm, alg_state)
66+
@test alg_state.stopping_criterion_state.at_iteration > 0
67+
AlgorithmsInterface.initialize_state!(problem, algorithm, s1, s1_state)
68+
@test s1_state.criteria_states[1].at_iteration == -1
69+
# Different constructors
70+
s2 = c1 & c2 & c3
71+
@test s1 & c3 == s2
72+
@test c1 & (c2 & c3) == s2
73+
@test s1 & s2 isa StopWhenAll
5274
end
5375

5476
@testset "StopWhenAny" begin
55-
s1 = StopAfterIteration(2) | StopAfter(Second(1))
77+
c1 = StopAfterIteration(2)
78+
c2 = StopAfter(Second(1))
79+
c3 = StopAfterIteration(3)
80+
81+
s1 = c1 | c2
5682
@test s1 isa StoppingCriterion
83+
@test s1 == StopWhenAny([c1, c2])
5784
@test sprint((io, x) -> show(io, MIME"text/plain"(), x), s1) ==
5885
"StopWhenAny with the Stopping Criteria:\n StopAfterIteration(2)\n StopAfter(Second(1))"
86+
@test !indicates_convergence(s1)
5987

6088
algorithm = AIT.DummyAlgorithm(s1)
6189
s1_state = initialize_state(problem, algorithm, s1)
62-
state_not_finished = AIT.DummyState(s1_state, 1)
63-
@test !is_finished(problem, algorithm, state_not_finished)
90+
91+
s1_str = summary(s1, s1_state)
92+
@test contains(s1_str, "Stop when _one_ ")
93+
@test contains(s1_str, "Overall: not reached")
94+
95+
@test isnothing(AlgorithmsInterface.get_reason(s1, s1_state))
96+
alg_state = AIT.DummyState(s1_state, 1)
97+
@test !is_finished(problem, algorithm, alg_state)
6498
s1_state.criteria_states[2].time = Second(2)
65-
@test is_finished(problem, algorithm, state_not_finished)
66-
state_not_finished.iteration = 2
67-
@test is_finished(problem, algorithm, state_not_finished)
99+
@test is_finished(problem, algorithm, alg_state)
100+
alg_state.iteration = 2
101+
@test is_finished(problem, algorithm, alg_state)
102+
# check that reset works (a) check with modification
103+
@test is_finished!(problem, algorithm, alg_state)
104+
@test alg_state.stopping_criterion_state.at_iteration > 0
105+
AlgorithmsInterface.initialize_state!(problem, algorithm, s1, s1_state)
106+
@test s1_state.criteria_states[1].at_iteration == -1
107+
# Different constructors
108+
s2 = c1 | c2 | c3
109+
@test s1 | c3 == s2
110+
@test c1 | (c2 | c3) == s2
111+
@test s1 | s2 isa StopWhenAny
68112
end

0 commit comments

Comments
 (0)