Skip to content

Commit 651577a

Browse files
committed
Implement most of StopWhenAny and StopWhenAll and their |, & operations.
1 parent d4c9948 commit 651577a

3 files changed

Lines changed: 231 additions & 8 deletions

File tree

docs/src/interface.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ consists of:
1515
The combination of the static information should be enough to initialize the varying data.
1616

1717
This general scheme is a guiding principle of the package, splitting information into _static_
18-
or _configuration_ types or data that allows to [`initialize`](@ref) a correspondint _variable_ data type.
18+
or _configuration_ types or data that allows to [`initialize_state`](@ref) a correspondint _variable_ data type.
1919

2020
The order of arguments is given by two ideas
2121

src/AlgorithmsInterface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ include("interface/interface.jl")
1818
include("stopping_criterion.jl")
1919

2020
export Algorithm, Problem, State
21-
export StoppingCriterion
22-
export StopAfter, StopAfterIteration
21+
export StoppingCriterion, StoppingCriterionState
22+
export StopAfter, StopAfterIteration, StopWhenAll, StopWhenAny
2323
export is_finished
2424
export initialize_state, initialize_state!, is_finished
2525
export get_iteration

src/stopping_criterion.jl

Lines changed: 228 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,9 @@ function indicates_convergence(sc::StoppingCriterion, scs::StoppingCriterionStat
6666
return length(get_reason(sc, scs)) > 0 ? indicates_convergence(sc) : false
6767
end
6868

69-
function get_summary end
69+
function summary end
7070
@doc """
71-
get_summary(sc::StoppingCriterion, scs::StoppingCriterionState)
71+
summary(sc::StoppingCriterion, scs::StoppingCriterionState)
7272
7373
Provide a summary of the status of a stopping criterion – its parameters and whether
7474
it currently indicates to stop. It should not be longer than one line
@@ -81,7 +81,230 @@ For the [`StopAfterIteration`](@ref) criterion, the summary looks like
8181
Max Iterations (15): not reached
8282
```
8383
"""
84-
get_summary(sc::StoppingCriterion, scs::StoppingCriterionState)
84+
summary(::StoppingCriterion, ::StoppingCriterionState)
85+
86+
#
87+
#
88+
# Meta StoppingCriteria
89+
@doc raw"""
90+
StopWhenAll <: StoppingCriterion
91+
92+
store a tuple of [`StoppingCriterion`](@ref)s and indicate to stop,
93+
when _all_ indicate to stop.
94+
95+
# Constructor
96+
97+
StopWhenAll(c::NTuple{N,StoppingCriterion} where N)
98+
StopWhenAll(c::StoppingCriterion,...)
99+
"""
100+
struct StopWhenAll{TCriteria<:Tuple} <: StoppingCriterion
101+
criteria::TCriteria
102+
StopWhenAll(c::Vector{StoppingCriterion}) = new{typeof(tuple(c...))}(tuple(c...))
103+
StopWhenAll(c...) = new{typeof(c)}(c)
104+
end
105+
106+
function indicates_convergence(sc::StopWhenAll)
107+
return any(indicates_convergence(sc_i) for sc_i in sc.criteria)
108+
end
109+
110+
function show(io::IO, sc::StopWhenAll)
111+
s = ""
112+
for sc_i in sc.criteria
113+
s = s * "\n * " * replace("$(sc_i)", "\n" => "\n ") #increase indent
114+
end
115+
return print(io, "StopWhenAll with the Stopping Criteria\n$(s)")
116+
end
117+
118+
"""
119+
&(s1,s2)
120+
s1 & s2
121+
122+
Combine two [`StoppingCriterion`](@ref) within an [`StopWhenAll`](@ref).
123+
If either `s1` (or `s2`) is already an [`StopWhenAll`](@ref), then `s2` (or `s1`) is
124+
appended to the list of [`StoppingCriterion`](@ref) within `s1` (or `s2`).
125+
126+
# Example
127+
a = StopAfterIteration(200) & StopAfter(Minute(1))
128+
129+
Is the same as
130+
131+
a = StopWhenAll(StopAfterIteration(200), StopAfter(Minute(1))
132+
"""
133+
Base.:&(s1::StoppingCriterion, s2::StoppingCriterion) = StopWhenAll(s1, s2)
134+
Base.:&(s1::StoppingCriterion, s2::StopWhenAll) = StopWhenAll(s1, s2.criteria...)
135+
Base.:&(s1::StopWhenAll, s2::StoppingCriterion) = StopWhenAll(s1.criteria..., s2)
136+
Base.:&(s1::StopWhenAll, s2::StopWhenAll) = StopWhenAll(s1.criteria..., s2.criteria...)
137+
138+
@doc raw"""
139+
StopWhenAny <: StoppingCriterion
140+
141+
store an array of [`StoppingCriterion`](@ref) elements and indicates to stop,
142+
when _any_ single one indicates to stop. The `reason` is given by the
143+
concatenation of all reasons (assuming that all non-indicating return `""`).
144+
145+
# Constructors
146+
147+
StopWhenAny(c::Vector{N,StoppingCriterion} where N)
148+
StopWhenAny(c::StoppingCriterion...)
149+
"""
150+
struct StopWhenAny{TCriteria<:Tuple} <: StoppingCriterion
151+
criteria::TCriteria
152+
StopWhenAny(c::Vector{<:StoppingCriterion}) = new{typeof(tuple(c...))}(tuple(c...))
153+
StopWhenAny(c::StoppingCriterion...) = new{typeof(c)}(c)
154+
end
155+
156+
function indicates_convergence(sc::StopWhenAny)
157+
return all(indicates_convergence(ci) for ci in sc.criteria)
158+
end
159+
function show(io::IO, sc::StopWhenAny)
160+
s = ""
161+
for sc_i in sc.criteria
162+
s = s * "\n * " * replace("$(sc_i)", "\n" => "\n ") #increase indent
163+
end
164+
return print(io, "StopWhenAny with the Stopping Criteria\n$(s)")
165+
end
166+
"""
167+
|(s1,s2)
168+
s1 | s2
169+
170+
Combine two [`StoppingCriterion`](@ref) within an [`StopWhenAny`](@ref).
171+
If either `s1` (or `s2`) is already an [`StopWhenAny`](@ref), then `s2` (or `s1`) is
172+
appended to the list of [`StoppingCriterion`](@ref) within `s1` (or `s2`)
173+
174+
# Example
175+
a = StopAfterIteration(200) | StopAfter(Minute(1))
176+
177+
Is the same as
178+
179+
a = StopWhenAny(StopAfterIteration(200), StopAfter(Minute(1)))
180+
"""
181+
Base.:|(s1::StoppingCriterion, s2::StoppingCriterion) = StopWhenAny(s1, s2)
182+
Base.:|(s1::StoppingCriterion, s2::StopWhenAny) = StopWhenAny(s1, s2.criteria...)
183+
Base.:|(s1::StopWhenAny, s2::StoppingCriterion) = StopWhenAny(s1.criteria..., s2)
184+
Base.:|(s1::StopWhenAny, s2::StopWhenAny) = StopWhenAny(s1.criteria..., s2.criteria...)
185+
186+
# A common state for stopping criteria working on tuples of stopping criteria
187+
"""
188+
GroupStoppingCriterionState <: StoppingCriterionState
189+
190+
A [`StoppingCriterionState`](@ref) that groups multiple [`StoppingCriterionState`](@ref)s
191+
internally as a tuple.
192+
This is for example used in combination with [`StopWhenAny`](@ref) and [`StopWhenAny`](@ref)
193+
194+
# Constructor
195+
GroupStoppingCriterionState(c::Vector{N,StoppingCriterionState} where N)
196+
GroupStoppingCriterionState(c::StoppingCriterionState...)
197+
"""
198+
mutable struct GroupStoppingCriterionState{TCriteriaStates<:Tuple} <: StoppingCriterionState
199+
criteria_states::TCriteriaStates
200+
at_iteration::Int
201+
GroupStoppingCriterionState(c::Vector{<:StoppingCriterionState}) =
202+
new{typeof(tuple(c...))}(tuple(c...), -1)
203+
GroupStoppingCriterionState(c::StoppingCriterionState...) = new{typeof(c)}(c, -1)
204+
end
205+
206+
function initialize_state(
207+
p::Problem,
208+
a::Algorithm,
209+
sc::Union{StopWhenAll,StopWhenAny};
210+
kwargs...,
211+
)
212+
return GroupStoppingCriterionState([
213+
initialize_state(p, a, sc_i; kwargs) for sc_i in sc.criteria
214+
])
215+
end
216+
function initialize_state!(
217+
scs::GroupStoppingCriterionState,
218+
p::Problem,
219+
a::Algorithm,
220+
sc::Union{StopWhenAll,StopWhenAny};
221+
kwargs...,
222+
)
223+
for (scs_i, sc_i) in zip(scs.criteria_states, sc.criteria)
224+
initialize_state!(scs_i, p, a, sc_i; kwargs...)
225+
end
226+
scs.at_iteration = -1
227+
return scs
228+
end
229+
230+
function get_reason(sc::Union{StopWhenAll,StopWhenAny}, scs::GroupStoppingCriterionState)
231+
if scs.at_iteration >= 0
232+
return string(
233+
(
234+
get_reason(sc_i, scs_i) for
235+
(sc_i, scs_i) in zip(sc.criteria, scs.criteria_states)
236+
)...,
237+
)
238+
end
239+
return ""
240+
end
241+
242+
function summary(sc::StopWhenAny, scs::GroupStoppingCriterionState)
243+
has_stopped = (scs.at_iteration >= 0)
244+
s = has_stopped ? "reached" : "not reached"
245+
r = "Stop When _one_ of the following are fulfilled:\n"
246+
for (sc_i, scs_i) in zip(sc.criteria, scs.criteria_states)
247+
s = replace(summary(sc_i, scs_i), "\n" => "\n ")
248+
r = "$r $(s)\n"
249+
end
250+
return "$(r)Overall: $s"
251+
end
252+
function summary(sc::StopWhenAll, scs::GroupStoppingCriterionState)
253+
has_stopped = (scs.at_iteration >= 0)
254+
s = has_stopped ? "reached" : "not reached"
255+
r = "Stop When _all_ of the following are fulfilled:\n"
256+
for (sc_i, scs_i) in zip(sc.criteria, scs.criteria_states)
257+
s = replace(summary(sc_i, scs_i), "\n" => "\n ")
258+
r = "$r $(s)\n"
259+
end
260+
return "$(r)Overall: $s"
261+
end
262+
# Meta functors
263+
function (scs::GroupStoppingCriterionState)(
264+
p::Problem,
265+
a::Algorithm,
266+
s::State,
267+
sc::StopWhenAll,
268+
)
269+
k = get_iteration(s)
270+
(k == 0) && (scs.at_iteration = -1) # reset on init
271+
if all(st -> st[2](p, a, s, st[1]), zip(sc.criteria, scs.criteria_states))
272+
scs.at_iteration = k
273+
return true
274+
end
275+
return false
276+
end
277+
278+
# `_fast_any(f, tup::Tuple)`` is functionally equivalent to `any(f, tup)`` but on Julia 1.10
279+
# this implementation is faster on heterogeneous tuples
280+
@inline _fast_any(f, tup::Tuple{}) = true
281+
@inline _fast_any(f, tup::Tuple{T}) where {T} = f(tup[1])
282+
@inline function _fast_any(f, tup::Tuple)
283+
if f(tup[1])
284+
return true
285+
else
286+
return _fast_any(f, tup[2:end])
287+
end
288+
end
289+
290+
function (scs::GroupStoppingCriterionState)(
291+
p::Problem,
292+
a::Algorithm,
293+
s::State,
294+
sc::StopWhenAny,
295+
)
296+
k = get_iteration(s)
297+
(k == 0) && (c.at_iteration = -1) # reset on init
298+
if _fast_any(st -> st[2](p, a, s, st[1]), zip(sc.criteria, scs.criteria_states))
299+
c.at_iteration = k
300+
return true
301+
end
302+
return false
303+
end
304+
305+
#
306+
#
307+
# Concrete Stopping Criteria
85308

86309
@doc raw"""
87310
StopAfterIteration <: StoppingCriterion
@@ -153,7 +376,7 @@ function get_reason(sc::StopAfterIteration, scs::DefaultStoppingCriterionState)
153376
return ""
154377
end
155378
indicates_convergence(sc::StopAfterIteration) = false
156-
function get_summary(sc::StopAfterIteration, scs::DefaultStoppingCriterionState)
379+
function summary(sc::StopAfterIteration, scs::DefaultStoppingCriterionState)
157380
has_stopped = (scs.at_iteration >= 0)
158381
s = has_stopped ? "reached" : "not reached"
159382
return "Max Iteration $(sc.max_iterations):\t$s"
@@ -246,7 +469,7 @@ function get_reason(sc::StopAfter, scs::StopAfterTimePeriodState)
246469
end
247470
return ""
248471
end
249-
function get_summary(sc::StopAfter, scs::StopAfterTimePeriodState)
472+
function summary(sc::StopAfter, scs::StopAfterTimePeriodState)
250473
has_stopped = (scs.at_iteration >= 0)
251474
s = has_stopped ? "reached" : "not reached"
252475
return "stopped after $(sc.threshold):\t$s"

0 commit comments

Comments
 (0)