|
| 1 | +# Newton's method for finding roots of a function |
| 2 | +# wrapped as a very simple iterative algorithm |
| 3 | + |
| 4 | +using AlgorithmsInterface |
| 5 | +import AlgorithmsInterface: initialize_state, initialize_state!, is_finished, solve!, step! |
| 6 | +using Test |
| 7 | + |
| 8 | +# Defining the structs |
| 9 | +# ------------------ |
| 10 | +struct RootFindingProblem <: Problem |
| 11 | + f::Function |
| 12 | + df::Function |
| 13 | +end |
| 14 | + |
| 15 | +struct NewtonMethod{S} <: Algorithm |
| 16 | + stopping_criterion::S |
| 17 | + # TODO: logging settings? stopping criterium initialization? |
| 18 | +end |
| 19 | + |
| 20 | +mutable struct NewtonState{S} <: State |
| 21 | + iteration::Int |
| 22 | + iterate::Float64 |
| 23 | + stopping_criterion::S |
| 24 | +end |
| 25 | + |
| 26 | +# Implementing the algorithm |
| 27 | +# -------------------------- |
| 28 | +function initialize_state(::RootFindingProblem, algorithm::NewtonMethod) |
| 29 | + return NewtonState(0, 1.0, algorithm.stopping_criterion) # hardcode initial guess to 1.0 |
| 30 | +end |
| 31 | +function initialize_state!(::RootFindingProblem, algorithm::NewtonMethod, state::NewtonState) |
| 32 | + state.iteration = 0 |
| 33 | + state.iterate = 1.0 |
| 34 | + state.stopping_criterion = algorithm.stopping_criterion |
| 35 | +end |
| 36 | + |
| 37 | +function step!(problem::RootFindingProblem, ::NewtonMethod, state::NewtonState) |
| 38 | + state.iterate -= problem.f(state.iterate) / problem.df(state.iterate) |
| 39 | + return state |
| 40 | +end |
| 41 | + |
| 42 | +# Testing the algorithm |
| 43 | +# --------------------- |
| 44 | +@testset "Babylonian square roots" begin |
| 45 | + f(x, a) = x^2 - a |
| 46 | + df(x, a) = 2x |
| 47 | + |
| 48 | + a = 612.0 |
| 49 | + problem = RootFindingProblem(x -> f(x, a), x -> df(x, a)) |
| 50 | + algorithm1 = NewtonMethod(StopAfterIteration(8)) |
| 51 | + solution1 = solve(problem, algorithm1) |
| 52 | + @test solution1.iterate ≈ sqrt(a) |
| 53 | + algorithm2 = NewtonMethod(StopAfterIteration(10)) |
| 54 | + solution2 = solve(problem, algorithm2) |
| 55 | + @test solution2.iterate ≈ sqrt(a) |
| 56 | + @test abs(solution2.iterate - sqrt(a)) < abs(solution1.iterate - sqrt(a)) |
| 57 | +end |
0 commit comments