Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions docs/src/interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,14 @@ With these definitions in place you can already run (assuming you also choose a
function heron_sqrt(x; maxiter = 10)
prob = SqrtProblem(x)
alg = HeronAlgorithm(StopAfterIteration(maxiter))
state = solve(prob, alg) # allocates & runs
return state.iterate
return solve(prob, alg) # allocates & runs
end

println("Approximate sqrt: ", heron_sqrt(16.0))
```

Note that [`solve`](@ref) will default to returning `state.iterate`.
If desired, this can be customized by altering [`finalize_state!`](@ref).
We will refine this example with better halting logic and logging shortly.

## Reference: Core interface types & functions
Expand Down
5 changes: 2 additions & 3 deletions docs/src/logging.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ end
function heron_sqrt(x; stopping_criterion = StopAfterIteration(10))
prob = SqrtProblem(x)
alg = HeronAlgorithm(stopping_criterion)
state = solve(prob, alg) # allocates & runs
return state.iterate
return solve(prob, alg) # allocates & runs
end
nothing # hide
```
Expand Down Expand Up @@ -329,7 +328,7 @@ function solve!(problem::Problem, algorithm::Algorithm, state::State; kwargs...)

emit_message(problem, algorithm, state, :Stop)

return state
return finalize_state!(problem, algorithm, state)
end
```

Expand Down
3 changes: 1 addition & 2 deletions docs/src/stopping_criterion.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,7 @@ We can again combine everything into a single function, but now make the stoppin
function heron_sqrt(x; stopping_criterion)
prob = SqrtProblem(x)
alg = HeronAlgorithm(stopping_criterion)
state = solve(prob, alg) # allocates & runs
return state.iterate, state.iteration
return solve(prob, alg) # allocates & runs
end

heron_sqrt(2; stopping_criterion = StopAfterIteration(10))
Expand Down
3 changes: 2 additions & 1 deletion src/AlgorithmsInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ include("logging.jl")
# general interface
export Algorithm, Problem, State
export initialize_state, initialize_state!
export finalize_state!

export step!, solve, solve!
export step!, solve, solve!, solve_loop!

# stopping criteria
export StoppingCriterion, StoppingCriterionState
Expand Down
58 changes: 46 additions & 12 deletions src/interface/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@ function initialize_state! end
@doc "$(_doc_init_state)"
initialize_state!(::Problem, ::Algorithm, ::State; kwargs...)


"""
output = finalize_state!(problem::Problem, algorithm::Algorithm, state::State)

Finalize the solver and decide what values get returned from the [`solve!`](@ref) call.
By default, this is a no-op and returns the `state.iterate`, but this allows for further
customization in other cases, for example to clean up used resources or output other data.
"""
finalize_state!(problem::Problem, algorithm::Algorithm, state::State) = state.iterate

# has to be defined before used in solve but is documented alphabetically after

@doc """
Expand All @@ -30,8 +40,21 @@ returns a state.
By default this method continues to call [`solve!`](@ref).
"""
function solve(problem::Problem, algorithm::Algorithm; kwargs...)
# obtain logger once to minimize overhead from accessing ScopedValue
# additionally handle logging initialization to enable stateful LoggingAction
logger = algorithm_logger()

# initialize the state and emit message
state = initialize_state(problem, algorithm; kwargs...)
return solve!(problem, algorithm, state; kwargs...)
emit_message(logger, problem, algorithm, state, :Start)

# main loop
state = solve_loop!(problem, algorithm, state)

# emit message about finished state
emit_message(logger, problem, algorithm, state, :Stop)

return finalize_state!(problem, algorithm, state)
end

@doc """
Expand All @@ -46,28 +69,39 @@ function solve!(problem::Problem, algorithm::Algorithm, state::State; kwargs...)
# obtain logger once to minimize overhead from accessing ScopedValue
# additionally handle logging initialization to enable stateful LoggingAction
logger = algorithm_logger()
# initialize_logger(logger, problem, algorithm, state)

# initialize the state and emit message
initialize_state!(problem, algorithm, state; kwargs...)
emit_message(logger, problem, algorithm, state, :Start)

# main body of the algorithm
# main loop
state = solve_loop!(problem, algorithm, state)

# emit message about finished state
emit_message(logger, problem, algorithm, state, :Stop)

return finalize_state!(problem, algorithm, state)
end

"""
solve_loop!(problem::Problem, algorithm::Algorithm, state::State)

Provide the main loop of the iterative `algorithm` for a given `problem` and starting `state`.

This loop consists of:
1. Checking for convergence with [`is_finished!`](@ref)
2. Incrementing the state [`increment!`](@ref)
3. Performing a step [`step!`](@ref)
4. Repeat
"""
function solve_loop!(problem::Problem, algorithm::Algorithm, state::State)
logger = algorithm_logger()
while !is_finished!(problem, algorithm, state)
# logging event between convergence check and algorithm step
emit_message(logger, problem, algorithm, state, :PreStep)

# algorithm step
increment!(state)
step!(problem, algorithm, state)

# logging event between algorithm step and convergence check
emit_message(logger, problem, algorithm, state, :PostStep)
end

# emit message about finished state
emit_message(logger, problem, algorithm, state, :Stop)

return state
end

Expand Down
6 changes: 3 additions & 3 deletions test/newton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ end
problem = RootFindingProblem(x -> f(x, a), x -> df(x, a))
algorithm1 = NewtonMethod(StopAfterIteration(8))
solution1 = solve(problem, algorithm1)
@test solution1.iterate ≈ sqrt(a)
@test solution1 ≈ sqrt(a)
algorithm2 = NewtonMethod(StopAfterIteration(10))
solution2 = solve(problem, algorithm2)
@test solution2.iterate ≈ sqrt(a)
@test abs(solution2.iterate - sqrt(a)) < abs(solution1.iterate - sqrt(a))
@test solution2 ≈ sqrt(a)
@test abs(solution2 - sqrt(a)) < abs(solution1 - sqrt(a))
end
Loading