Skip to content

Commit 1aeb445

Browse files
committed
A few points from the discussions
* remove _fast_any * get_reason now returns nothing if the stopping criterion is not finished * summary is now Base.summary.
1 parent 3460c22 commit 1aeb445

1 file changed

Lines changed: 16 additions & 25 deletions

File tree

src/stopping_criterion.jl

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,11 @@ If so it returns whether `sc` itself indicates convergence, otherwise it returns
6363
since the algorithm has then not yet stopped.
6464
"""
6565
function indicates_convergence(sc::StoppingCriterion, scs::StoppingCriterionState)
66-
return length(get_reason(sc, scs)) > 0 ? indicates_convergence(sc) : false
66+
return isnothing(get_reason(sc, scs)) ? indicates_convergence(sc) : false
6767
end
6868

69-
function summary end
7069
@doc """
71-
summary(sc::StoppingCriterion, scs::StoppingCriterionState)
70+
summary(io::IO, sc::StoppingCriterion, scs::StoppingCriterionState)
7271
7372
Provide a summary of the status of a stopping criterion – its parameters and whether
7473
it currently indicates to stop. It should not be longer than one line
@@ -81,7 +80,7 @@ For the [`StopAfterIteration`](@ref) criterion, the summary looks like
8180
Max Iterations (15): not reached
8281
```
8382
"""
84-
summary(::StoppingCriterion, ::StoppingCriterionState)
83+
Base.summary(io::IO, ::StoppingCriterion, ::StoppingCriterionState)
8584

8685
#
8786
#
@@ -106,7 +105,7 @@ function indicates_convergence(sc::StopWhenAll)
106105
return any(indicates_convergence, sc.criteria)
107106
end
108107

109-
function show(io::IO, sc::StopWhenAll)
108+
function Base.show(io::IO, sc::StopWhenAll)
110109
s = ""
111110
for sc_i in sc.criteria
112111
s = s * "\n * " * replace("$(sc_i)", "\n" => "\n ") #increase indent
@@ -155,7 +154,7 @@ end
155154
function indicates_convergence(sc::StopWhenAny)
156155
return all(indicates_convergence, sc.criteria)
157156
end
158-
function show(io::IO, sc::StopWhenAny)
157+
function Base.show(io::IO, sc::StopWhenAny)
159158
s = ""
160159
for sc_i in sc.criteria
161160
s = s * "\n * " * replace("$(sc_i)", "\n" => "\n ") #increase indent
@@ -235,10 +234,10 @@ function get_reason(sc::Union{StopWhenAll,StopWhenAny}, scs::GroupStoppingCriter
235234
)...,
236235
)
237236
end
238-
return ""
237+
return nothing
239238
end
240239

241-
function summary(sc::StopWhenAny, scs::GroupStoppingCriterionState)
240+
function Base.summary(io::IO, sc::StopWhenAny, scs::GroupStoppingCriterionState)
242241
has_stopped = (scs.at_iteration >= 0)
243242
s = has_stopped ? "reached" : "not reached"
244243
r = "Stop When _one_ of the following are fulfilled:\n"
@@ -248,15 +247,15 @@ function summary(sc::StopWhenAny, scs::GroupStoppingCriterionState)
248247
end
249248
return "$(r)Overall: $s"
250249
end
251-
function summary(sc::StopWhenAll, scs::GroupStoppingCriterionState)
250+
function Base.summary(io::IO, sc::StopWhenAll, scs::GroupStoppingCriterionState)
252251
has_stopped = (scs.at_iteration >= 0)
253252
s = has_stopped ? "reached" : "not reached"
254253
r = "Stop When _all_ of the following are fulfilled:\n"
255254
for (sc_i, scs_i) in zip(sc.criteria, scs.criteria_states)
256255
s = replace(summary(sc_i, scs_i), "\n" => "\n ")
257256
r = "$r $(s)\n"
258257
end
259-
return "$(r)Overall: $s"
258+
return print(io, "$(r)Overall: $s")
260259
end
261260
# Meta functors
262261
function (scs::GroupStoppingCriterionState)(
@@ -274,14 +273,6 @@ function (scs::GroupStoppingCriterionState)(
274273
return false
275274
end
276275

277-
# `_fast_any(f, tup::Tuple)`` is functionally equivalent to `any(f, tup)`` but on Julia 1.10
278-
# this implementation is faster on heterogeneous tuples
279-
@inline _fast_any(f, tup::Tuple{}) = true
280-
@inline _fast_any(f, tup::Tuple{T}) where {T} = f(tup[1])
281-
@inline function _fast_any(f, tup::Tuple)
282-
f(first(tup[1])) || _fast_any(f, Base.tail(tup))
283-
end
284-
285276
function (scs::GroupStoppingCriterionState)(
286277
p::Problem,
287278
a::Algorithm,
@@ -290,7 +281,7 @@ function (scs::GroupStoppingCriterionState)(
290281
)
291282
k = get_iteration(s)
292283
(k == 0) && (c.at_iteration = -1) # reset on init
293-
if _fast_any(st -> st[2](p, a, s, st[1]), zip(sc.criteria, scs.criteria_states))
284+
if any(st -> st[2](p, a, s, st[1]), zip(sc.criteria, scs.criteria_states))
294285
c.at_iteration = k
295286
return true
296287
end
@@ -368,13 +359,13 @@ function get_reason(sc::StopAfterIteration, scs::DefaultStoppingCriterionState)
368359
if scs.at_iteration >= sc.max_iterations
369360
return "At iteration $(scs.at_iteration) the algorithm reached its maximal number of iterations ($(sc.max_iterations)).\n"
370361
end
371-
return ""
362+
return nothing
372363
end
373364
indicates_convergence(sc::StopAfterIteration) = false
374-
function summary(sc::StopAfterIteration, scs::DefaultStoppingCriterionState)
365+
function Base.summary(io::IO, sc::StopAfterIteration, scs::DefaultStoppingCriterionState)
375366
has_stopped = (scs.at_iteration >= 0)
376367
s = has_stopped ? "reached" : "not reached"
377-
return "Max Iteration $(sc.max_iterations):\t$s"
368+
return print(io, "Max Iteration $(sc.max_iterations):\t$s")
378369
end
379370

380371
"""
@@ -462,11 +453,11 @@ function get_reason(sc::StopAfter, scs::StopAfterTimePeriodState)
462453
if (scs.at_iteration >= 0)
463454
return "After iteration $(scs.at_iteration) the algorithm ran for $(floor(scs.time, typeof(sc.threshold))) (threshold: $(sc.threshold)).\n"
464455
end
465-
return ""
456+
return nothing
466457
end
467-
function summary(sc::StopAfter, scs::StopAfterTimePeriodState)
458+
function Base.summary(io::IO, sc::StopAfter, scs::StopAfterTimePeriodState)
468459
has_stopped = (scs.at_iteration >= 0)
469460
s = has_stopped ? "reached" : "not reached"
470-
return "stopped after $(sc.threshold):\t$s"
461+
return print(io, "stopped after $(sc.threshold):\t$s")
471462
end
472463
indicates_convergence(sc::StopAfter) = false

0 commit comments

Comments
 (0)