Skip to content

Commit d4c9948

Browse files
committed
finish redesign.
1 parent 8958870 commit d4c9948

4 files changed

Lines changed: 104 additions & 54 deletions

File tree

src/interface/interface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ By default this method continues to call [`solve!`](@ref).
4343
"""
4444
function solve(p::Problem, a::Algorithm; kwargs...)
4545
s = initialize_state(p, a; kwargs...)
46-
return solve!(p, a, s; kwargs...)
46+
return solve!(s, p, a; kwargs...)
4747
end
4848

4949
@doc """
@@ -70,4 +70,4 @@ function step! end
7070
Perform the current step of an [`Algorithm`](@ref) `a` solving [`Problem`](@ref) `p`
7171
modifying the algorithms [`State`](@ref) `s`.
7272
"""
73-
step!(state::State, p::Problem, a::Algorithm)
73+
step!(s::State, p::Problem, a::Algorithm)

src/interface/state.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,4 @@ Return the [`StoppingCriterionState`](@ref) of the [`State`](@ref) `s`.
6868
6969
The default assumes that the criterion is stored in `s.stopping_criterion_state`.
7070
"""
71-
get_stopping_criterion(s::State) = s.stopping_criterion_state
71+
get_stopping_criterion_state(s::State) = s.stopping_criterion_state

src/stopping_criterion.jl

Lines changed: 85 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44
An abstract type to represent a stopping criterion.
55
66
A concrete [`StoppingCriterion`](@ref) `sc` should also implement a
7-
[`initialize(sc::StoppingCriterion)`](@ref) function to create its accompaying
7+
[`initialize_state(problem::Problem, algorithm::Algorithm, sc::StoppingCriterion; kwargs...)`](@ref) function to create its accompanying
88
[`StoppingCriterionState`](@ref).
9+
as well as the corresponting mutating variant to reset such a [`StoppingCriterionState`](@ref).
10+
911
It should usually implement
1012
1113
* `indicates_convergence(sc)` a boolean whether or not this stopping criterion would indicate
1214
that the algorithm has converged, if it indicates to stop.
13-
* `show(io::IO, scs)` for use in REPL and display within an [`Algorithm`](@ref).
1415
"""
1516
abstract type StoppingCriterion end
1617

@@ -98,7 +99,7 @@ A simple stopping criterion to stop after a maximal number of iterations.
9899
initialize the functor to indicate to stop after `maxIter` iterations.
99100
"""
100101
struct StopAfterIteration <: StoppingCriterion
101-
max_iterations::Int
102+
max_iterations::Int
102103
end
103104

104105
"""
@@ -118,35 +119,44 @@ mutable struct DefaultStoppingCriterionState
118119
DefaultStoppingCriterionState() = new(-1)
119120
end
120121

121-
initialize(::Problem, ::Algorithm, ::State, ::StopAfterIteration) = DefaultStoppingCriterionState()
122-
function initialize!(scs::DefaultStoppingCriterionState, ::Problem, ::Algorithm, ::State, ::StopAfterIteration)
123-
scs.indicated_convergence_at = -1
122+
initialize_state(::Problem, ::Algorithm, ::StopAfterIteration; kwargs...) =
123+
DefaultStoppingCriterionState()
124+
function initialize_state!(
125+
scs::DefaultStoppingCriterionState,
126+
::Problem,
127+
::Algorithm,
128+
::StopAfterIteration;
129+
kwargs...,
130+
)
131+
scs.at_iteration = -1
124132
return scs
125133
end
126134

127-
function (sc::DefaultStoppingCriterionState)(::Problem, ::Algorithm, s::State, sc::StopAfterIteration)
135+
function (scs::DefaultStoppingCriterionState)(
136+
::Problem,
137+
::Algorithm,
138+
s::State,
139+
sc::StopAfterIteration,
140+
)
128141
k = get_iteration(s)
129-
(k == 0) && (sc.at_iteration = -1)
142+
(k == 0) && (scs.at_iteration = -1)
130143
if k >= sc.max_iterations
131-
sc.at_iteration = k
144+
scs.at_iteration = k
132145
return true
133146
end
134147
return false
135148
end
136-
function get_reason(c::StopAfterIteration, scs::DefaultStoppingCriterionState)
137-
if c.at_iteration >= c.max_iterations
138-
return "At iteration $(c.at_iteration) the algorithm reached its maximal number of iterations ($(c.max_iterations)).\n"
149+
function get_reason(sc::StopAfterIteration, scs::DefaultStoppingCriterionState)
150+
if scs.at_iteration >= sc.max_iterations
151+
return "At iteration $(scs.at_iteration) the algorithm reached its maximal number of iterations ($(sc.max_iterations)).\n"
139152
end
140153
return ""
141154
end
142155
indicates_convergence(sc::StopAfterIteration) = false
143-
function get_summary(c::StopAfterIteration)
144-
has_stopped = (c.at_iteration >= 0)
156+
function get_summary(sc::StopAfterIteration, scs::DefaultStoppingCriterionState)
157+
has_stopped = (scs.at_iteration >= 0)
145158
s = has_stopped ? "reached" : "not reached"
146-
return "Max Iteration $(c.max_iterations):\t$s"
147-
end
148-
function show(io::IO, c::StopAfterIteration)
149-
return print(io, "StopAfterIteration($(c.max_iterations))\n $(get_summary(c))")
159+
return "Max Iteration $(sc.max_iterations):\t$s"
150160
end
151161

152162
"""
@@ -159,57 +169,86 @@ for example `Minute(15)`.
159169
# Fields
160170
161171
* `threshold` stores the `Period` after which to stop
162-
* `start` stores the starting time when the algorithm is started, that is a call with `i=0`.
163-
* `time` stores the elapsed time
164-
* `at_iteration` indicates at which iteration (including `i=0`) the stopping criterion
165-
was fulfilled and is `-1` while it is not fulfilled.
166172
167173
# Constructor
168174
169175
StopAfter(t)
170176
171177
initialize the stopping criterion to a `Period t` to stop after.
172178
"""
173-
mutable struct StopAfter <: StoppingCriterion
179+
struct StopAfter <: StoppingCriterion
174180
threshold::Period
175-
start::Nanosecond
176-
time::Nanosecond
177-
at_iteration::Int
178181
function StopAfter(t::Period)
179-
return if value(t) < 0
182+
if value(t) < 0
180183
error("You must provide a positive time period")
181184
else
182-
new(t, Nanosecond(0), Nanosecond(0), -1)
185+
s = new(t)
183186
end
187+
return s
184188
end
185189
end
186-
function (sc::StopAfter)(::Problem, ::Algorithm, s::State)
190+
191+
@doc """
192+
StopAfterTimePeriodState <: StoppingCriterionState
193+
194+
A state for stopping criteria that are based on time measurements,
195+
for example [`StopAfter`](@ref).
196+
197+
* `start` stores the starting time when the algorithm is started, that is a call with `i=0`.
198+
* `time` stores the elapsed time
199+
* `at_iteration` indicates at which iteration (including `i=0`) the stopping criterion
200+
was fulfilled and is `-1` while it is not fulfilled.
201+
202+
"""
203+
mutable struct StopAfterTimePeriodState <: StoppingCriterionState
204+
start::Nanosecond
205+
time::Nanosecond
206+
at_iteration::Int
207+
function StopAfterTimePeriodState()
208+
return new(Nanosecond(0), Nanosecond(0), -1)
209+
end
210+
end
211+
212+
initialize_state(::Problem, ::Algorithm, ::StopAfter; kwargs...) =
213+
StopAfterTimePeriodState()
214+
215+
function initialize_state!(
216+
scs::DefaultStoppingCriterionState,
217+
::Problem,
218+
::Algorithm,
219+
::StopAfter;
220+
kwargs...,
221+
)
222+
scs.start = Nanosecond(0)
223+
scs.time = Nanosecond(0)
224+
scs.at_iteration = -1
225+
return scs
226+
end
227+
228+
function (scs::StopAfterTimePeriodState)(::Problem, ::Algorithm, s::State, sc::StopAfter)
187229
k = get_iteration(s)
188-
if value(sc.start) == 0 || k <= 0 # (re)start timer
189-
sc.at_iteration = -1
190-
sc.start = Nanosecond(time_ns())
191-
sc.time = Nanosecond(0)
230+
if value(scs.start) == 0 || k <= 0 # (re)start timer
231+
scs.at_iteration = -1
232+
scs.start = Nanosecond(time_ns())
233+
scs.time = Nanosecond(0)
192234
else
193-
sc.time = Nanosecond(time_ns()) - sc.start
194-
if k > 0 && (sc.time > Nanosecond(sc.threshold))
195-
sc.at_iteration = k
235+
scs.time = Nanosecond(time_ns()) - scs.start
236+
if k > 0 && (scs.time > Nanosecond(sc.threshold))
237+
scs.at_iteration = k
196238
return true
197239
end
198240
end
199241
return false
200242
end
201-
function get_reason(sc::StopAfter)
202-
if (c.at_iteration >= 0)
203-
return "After iteration $(sc.at_iteration) the algorithm ran for $(floor(c.time, typeof(c.threshold))) (threshold: $(c.threshold)).\n"
243+
function get_reason(sc::StopAfter, scs::StopAfterTimePeriodState)
244+
if (scs.at_iteration >= 0)
245+
return "After iteration $(scs.at_iteration) the algorithm ran for $(floor(scs.time, typeof(sc.threshold))) (threshold: $(sc.threshold)).\n"
204246
end
205247
return ""
206248
end
207-
function get_summary(c::StopAfter)
208-
has_stopped = (c.at_iteration >= 0)
249+
function get_summary(sc::StopAfter, scs::StopAfterTimePeriodState)
250+
has_stopped = (scs.at_iteration >= 0)
209251
s = has_stopped ? "reached" : "not reached"
210-
return "stopped after $(c.threshold):\t$s"
211-
end
212-
indicates_convergence(c::StopAfter) = false
213-
function show(io::IO, sc::StopAfter)
214-
return print(io, "StopAfter($(repr(sc.threshold)))\n $(get_summary(sc))")
252+
return "stopped after $(sc.threshold):\t$s"
215253
end
254+
indicates_convergence(sc::StopAfter) = false

test/newton.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,29 @@ end
2020
mutable struct NewtonState{S} <: State
2121
iteration::Int
2222
iterate::Float64
23-
stopping_criterion::S
23+
stopping_criterion_state::S
2424
end
2525

2626
# Implementing the algorithm
2727
# --------------------------
28-
function initialize_state(::RootFindingProblem, algorithm::NewtonMethod)
29-
return NewtonState(0, 1.0, algorithm.stopping_criterion) # hardcode initial guess to 1.0
28+
function initialize_state(problem::RootFindingProblem, algorithm::NewtonMethod)
29+
scs = initialize_state(problem, algorithm, algorithm.stopping_criterion)
30+
return NewtonState(0, 1.0, scs) # hardcode initial guess to 1.0
3031
end
31-
function initialize_state!(::RootFindingProblem, algorithm::NewtonMethod, state::NewtonState)
32+
function initialize_state!(
33+
state::NewtonState,
34+
problem::RootFindingProblem,
35+
algorithm::NewtonMethod,
36+
)
3237
state.iteration = 0
3338
state.iterate = 1.0
34-
state.stopping_criterion = algorithm.stopping_criterion
39+
initialize_state!(
40+
state.stopping_criterion_state,
41+
problem,
42+
algorithm,
43+
algorithm.stopping_criterion,
44+
)
45+
return state
3546
end
3647

3748
function step!(problem::RootFindingProblem, ::NewtonMethod, state::NewtonState)

0 commit comments

Comments
 (0)