diff --git a/docs/src/interface.md b/docs/src/interface.md index 9d4817a..4ec3a27 100644 --- a/docs/src/interface.md +++ b/docs/src/interface.md @@ -128,13 +128,14 @@ With these definitions in place you can already run (assuming you also choose a function heron_sqrt(x; maxiter = 10) prob = SqrtProblem(x) alg = HeronAlgorithm(StopAfterIteration(maxiter)) - state = solve(prob, alg) # allocates & runs - return state.iterate + return solve(prob, alg) # allocates & runs end println("Approximate sqrt: ", heron_sqrt(16.0)) ``` +Note that [`solve`](@ref) will default to returning `state.iterate`. +If desired, this can be customized by altering [`finalize_state!`](@ref). We will refine this example with better halting logic and logging shortly. ## Reference: Core interface types & functions diff --git a/docs/src/logging.md b/docs/src/logging.md index c167abb..b2b5da1 100644 --- a/docs/src/logging.md +++ b/docs/src/logging.md @@ -75,8 +75,7 @@ end function heron_sqrt(x; stopping_criterion = StopAfterIteration(10)) prob = SqrtProblem(x) alg = HeronAlgorithm(stopping_criterion) - state = solve(prob, alg) # allocates & runs - return state.iterate + return solve(prob, alg) # allocates & runs end nothing # hide ``` @@ -329,7 +328,7 @@ function solve!(problem::Problem, algorithm::Algorithm, state::State; kwargs...) emit_message(problem, algorithm, state, :Stop) - return state + return finalize_state!(problem, algorithm, state) end ``` diff --git a/docs/src/stopping_criterion.md b/docs/src/stopping_criterion.md index f03c3e4..a7e2218 100644 --- a/docs/src/stopping_criterion.md +++ b/docs/src/stopping_criterion.md @@ -112,8 +112,7 @@ We can again combine everything into a single function, but now make the stoppin function heron_sqrt(x; stopping_criterion) prob = SqrtProblem(x) alg = HeronAlgorithm(stopping_criterion) - state = solve(prob, alg) # allocates & runs - return state.iterate, state.iteration + return solve(prob, alg) # allocates & runs end heron_sqrt(2; stopping_criterion = StopAfterIteration(10)) diff --git a/src/AlgorithmsInterface.jl b/src/AlgorithmsInterface.jl index 2a9ed68..0907dbd 100644 --- a/src/AlgorithmsInterface.jl +++ b/src/AlgorithmsInterface.jl @@ -23,8 +23,9 @@ include("logging.jl") # general interface export Algorithm, Problem, State export initialize_state, initialize_state! +export finalize_state! -export step!, solve, solve! +export step!, solve, solve!, solve_loop! # stopping criteria export StoppingCriterion, StoppingCriterionState diff --git a/src/interface/interface.jl b/src/interface/interface.jl index 8640795..f294dce 100644 --- a/src/interface/interface.jl +++ b/src/interface/interface.jl @@ -17,6 +17,16 @@ function initialize_state! end @doc "$(_doc_init_state)" initialize_state!(::Problem, ::Algorithm, ::State; kwargs...) + +""" + output = finalize_state!(problem::Problem, algorithm::Algorithm, state::State) + +Finalize the solver and decide what values get returned from the [`solve!`](@ref) call. +By default, this is a no-op and returns the `state.iterate`, but this allows for further +customization in other cases, for example to clean up used resources or output other data. +""" +finalize_state!(problem::Problem, algorithm::Algorithm, state::State) = state.iterate + # has to be defined before used in solve but is documented alphabetically after @doc """ @@ -30,8 +40,21 @@ returns a state. By default this method continues to call [`solve!`](@ref). """ function solve(problem::Problem, algorithm::Algorithm; kwargs...) + # obtain logger once to minimize overhead from accessing ScopedValue + # additionally handle logging initialization to enable stateful LoggingAction + logger = algorithm_logger() + + # initialize the state and emit message state = initialize_state(problem, algorithm; kwargs...) - return solve!(problem, algorithm, state; kwargs...) + emit_message(logger, problem, algorithm, state, :Start) + + # main loop + state = solve_loop!(problem, algorithm, state) + + # emit message about finished state + emit_message(logger, problem, algorithm, state, :Stop) + + return finalize_state!(problem, algorithm, state) end @doc """ @@ -46,28 +69,39 @@ function solve!(problem::Problem, algorithm::Algorithm, state::State; kwargs...) # obtain logger once to minimize overhead from accessing ScopedValue # additionally handle logging initialization to enable stateful LoggingAction logger = algorithm_logger() - # initialize_logger(logger, problem, algorithm, state) # initialize the state and emit message initialize_state!(problem, algorithm, state; kwargs...) emit_message(logger, problem, algorithm, state, :Start) - # main body of the algorithm + # main loop + state = solve_loop!(problem, algorithm, state) + + # emit message about finished state + emit_message(logger, problem, algorithm, state, :Stop) + + return finalize_state!(problem, algorithm, state) +end + +""" + solve_loop!(problem::Problem, algorithm::Algorithm, state::State) + +Provide the main loop of the iterative `algorithm` for a given `problem` and starting `state`. + +This loop consists of: +1. Checking for convergence with [`is_finished!`](@ref) +2. Incrementing the state [`increment!`](@ref) +3. Performing a step [`step!`](@ref) +4. Repeat +""" +function solve_loop!(problem::Problem, algorithm::Algorithm, state::State) + logger = algorithm_logger() while !is_finished!(problem, algorithm, state) - # logging event between convergence check and algorithm step emit_message(logger, problem, algorithm, state, :PreStep) - - # algorithm step increment!(state) step!(problem, algorithm, state) - - # logging event between algorithm step and convergence check emit_message(logger, problem, algorithm, state, :PostStep) end - - # emit message about finished state - emit_message(logger, problem, algorithm, state, :Stop) - return state end diff --git a/test/newton.jl b/test/newton.jl index 3a038aa..a60cb90 100644 --- a/test/newton.jl +++ b/test/newton.jl @@ -60,9 +60,9 @@ end problem = RootFindingProblem(x -> f(x, a), x -> df(x, a)) algorithm1 = NewtonMethod(StopAfterIteration(8)) solution1 = solve(problem, algorithm1) - @test solution1.iterate ≈ sqrt(a) + @test solution1 ≈ sqrt(a) algorithm2 = NewtonMethod(StopAfterIteration(10)) solution2 = solve(problem, algorithm2) - @test solution2.iterate ≈ sqrt(a) - @test abs(solution2.iterate - sqrt(a)) < abs(solution1.iterate - sqrt(a)) + @test solution2 ≈ sqrt(a) + @test abs(solution2 - sqrt(a)) < abs(solution1 - sqrt(a)) end