Skip to content

Commit 5f8c363

Browse files
authored
Interface improvements: finalize_state! and deduplicate initialize_state(!) calls (#10)
* allow customization of output * avoid double `initialize_state` call in `solve` * centralize `_solve_body!` code * refactor to use `solve_loop!` * swap `finalize` and `emit_message` order * default output `state.iterate`
1 parent 0879500 commit 5f8c363

6 files changed

Lines changed: 57 additions & 23 deletions

File tree

docs/src/interface.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,14 @@ With these definitions in place you can already run (assuming you also choose a
128128
function heron_sqrt(x; maxiter = 10)
129129
prob = SqrtProblem(x)
130130
alg = HeronAlgorithm(StopAfterIteration(maxiter))
131-
state = solve(prob, alg) # allocates & runs
132-
return state.iterate
131+
return solve(prob, alg) # allocates & runs
133132
end
134133
135134
println("Approximate sqrt: ", heron_sqrt(16.0))
136135
```
137136

137+
Note that [`solve`](@ref) will default to returning `state.iterate`.
138+
If desired, this can be customized by altering [`finalize_state!`](@ref).
138139
We will refine this example with better halting logic and logging shortly.
139140

140141
## Reference: Core interface types & functions

docs/src/logging.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,7 @@ end
7575
function heron_sqrt(x; stopping_criterion = StopAfterIteration(10))
7676
prob = SqrtProblem(x)
7777
alg = HeronAlgorithm(stopping_criterion)
78-
state = solve(prob, alg) # allocates & runs
79-
return state.iterate
78+
return solve(prob, alg) # allocates & runs
8079
end
8180
nothing # hide
8281
```
@@ -329,7 +328,7 @@ function solve!(problem::Problem, algorithm::Algorithm, state::State; kwargs...)
329328

330329
emit_message(problem, algorithm, state, :Stop)
331330

332-
return state
331+
return finalize_state!(problem, algorithm, state)
333332
end
334333
```
335334

docs/src/stopping_criterion.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,7 @@ We can again combine everything into a single function, but now make the stoppin
112112
function heron_sqrt(x; stopping_criterion)
113113
prob = SqrtProblem(x)
114114
alg = HeronAlgorithm(stopping_criterion)
115-
state = solve(prob, alg) # allocates & runs
116-
return state.iterate, state.iteration
115+
return solve(prob, alg) # allocates & runs
117116
end
118117
119118
heron_sqrt(2; stopping_criterion = StopAfterIteration(10))

src/AlgorithmsInterface.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ include("logging.jl")
2323
# general interface
2424
export Algorithm, Problem, State
2525
export initialize_state, initialize_state!
26+
export finalize_state!
2627

27-
export step!, solve, solve!
28+
export step!, solve, solve!, solve_loop!
2829

2930
# stopping criteria
3031
export StoppingCriterion, StoppingCriterionState

src/interface/interface.jl

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,16 @@ function initialize_state! end
1717
@doc "$(_doc_init_state)"
1818
initialize_state!(::Problem, ::Algorithm, ::State; kwargs...)
1919

20+
21+
"""
22+
output = finalize_state!(problem::Problem, algorithm::Algorithm, state::State)
23+
24+
Finalize the solver and decide what values get returned from the [`solve!`](@ref) call.
25+
By default, this is a no-op and returns the `state.iterate`, but this allows for further
26+
customization in other cases, for example to clean up used resources or output other data.
27+
"""
28+
finalize_state!(problem::Problem, algorithm::Algorithm, state::State) = state.iterate
29+
2030
# has to be defined before used in solve but is documented alphabetically after
2131

2232
@doc """
@@ -30,8 +40,21 @@ returns a state.
3040
By default this method continues to call [`solve!`](@ref).
3141
"""
3242
function solve(problem::Problem, algorithm::Algorithm; kwargs...)
43+
# obtain logger once to minimize overhead from accessing ScopedValue
44+
# additionally handle logging initialization to enable stateful LoggingAction
45+
logger = algorithm_logger()
46+
47+
# initialize the state and emit message
3348
state = initialize_state(problem, algorithm; kwargs...)
34-
return solve!(problem, algorithm, state; kwargs...)
49+
emit_message(logger, problem, algorithm, state, :Start)
50+
51+
# main loop
52+
state = solve_loop!(problem, algorithm, state)
53+
54+
# emit message about finished state
55+
emit_message(logger, problem, algorithm, state, :Stop)
56+
57+
return finalize_state!(problem, algorithm, state)
3558
end
3659

3760
@doc """
@@ -46,28 +69,39 @@ function solve!(problem::Problem, algorithm::Algorithm, state::State; kwargs...)
4669
# obtain logger once to minimize overhead from accessing ScopedValue
4770
# additionally handle logging initialization to enable stateful LoggingAction
4871
logger = algorithm_logger()
49-
# initialize_logger(logger, problem, algorithm, state)
5072

5173
# initialize the state and emit message
5274
initialize_state!(problem, algorithm, state; kwargs...)
5375
emit_message(logger, problem, algorithm, state, :Start)
5476

55-
# main body of the algorithm
77+
# main loop
78+
state = solve_loop!(problem, algorithm, state)
79+
80+
# emit message about finished state
81+
emit_message(logger, problem, algorithm, state, :Stop)
82+
83+
return finalize_state!(problem, algorithm, state)
84+
end
85+
86+
"""
87+
solve_loop!(problem::Problem, algorithm::Algorithm, state::State)
88+
89+
Provide the main loop of the iterative `algorithm` for a given `problem` and starting `state`.
90+
91+
This loop consists of:
92+
1. Checking for convergence with [`is_finished!`](@ref)
93+
2. Incrementing the state [`increment!`](@ref)
94+
3. Performing a step [`step!`](@ref)
95+
4. Repeat
96+
"""
97+
function solve_loop!(problem::Problem, algorithm::Algorithm, state::State)
98+
logger = algorithm_logger()
5699
while !is_finished!(problem, algorithm, state)
57-
# logging event between convergence check and algorithm step
58100
emit_message(logger, problem, algorithm, state, :PreStep)
59-
60-
# algorithm step
61101
increment!(state)
62102
step!(problem, algorithm, state)
63-
64-
# logging event between algorithm step and convergence check
65103
emit_message(logger, problem, algorithm, state, :PostStep)
66104
end
67-
68-
# emit message about finished state
69-
emit_message(logger, problem, algorithm, state, :Stop)
70-
71105
return state
72106
end
73107

test/newton.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ end
6060
problem = RootFindingProblem(x -> f(x, a), x -> df(x, a))
6161
algorithm1 = NewtonMethod(StopAfterIteration(8))
6262
solution1 = solve(problem, algorithm1)
63-
@test solution1.iterate sqrt(a)
63+
@test solution1 sqrt(a)
6464
algorithm2 = NewtonMethod(StopAfterIteration(10))
6565
solution2 = solve(problem, algorithm2)
66-
@test solution2.iterate sqrt(a)
67-
@test abs(solution2.iterate - sqrt(a)) < abs(solution1.iterate - sqrt(a))
66+
@test solution2 sqrt(a)
67+
@test abs(solution2 - sqrt(a)) < abs(solution1 - sqrt(a))
6868
end

0 commit comments

Comments
 (0)