Skip to content

Commit 2dab83b

Browse files
committed
Add simple test
1 parent 46f4ba6 commit 2dab83b

1 file changed

Lines changed: 57 additions & 0 deletions

File tree

test/newton.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Newton's method for finding roots of a function
2+
# wrapped as a very simple iterative algorithm
3+
4+
using AlgorithmsInterface
5+
import AlgorithmsInterface: initialize_state, initialize_state!, is_finished, solve!, step!
6+
using Test
7+
8+
# Defining the structs
9+
# ------------------
10+
struct RootFindingProblem <: Problem
11+
f::Function
12+
df::Function
13+
end
14+
15+
struct NewtonMethod{S} <: Algorithm
16+
stopping_criterion::S
17+
# TODO: logging settings? stopping criterium initialization?
18+
end
19+
20+
mutable struct NewtonState{S} <: State
21+
iteration::Int
22+
iterate::Float64
23+
stopping_criterion::S
24+
end
25+
26+
# Implementing the algorithm
27+
# --------------------------
28+
function initialize_state(::RootFindingProblem, algorithm::NewtonMethod)
29+
return NewtonState(0, 1.0, algorithm.stopping_criterion) # hardcode initial guess to 1.0
30+
end
31+
function initialize_state!(::RootFindingProblem, algorithm::NewtonMethod, state::NewtonState)
32+
state.iteration = 0
33+
state.iterate = 1.0
34+
state.stopping_criterion = algorithm.stopping_criterion
35+
end
36+
37+
function step!(problem::RootFindingProblem, ::NewtonMethod, state::NewtonState)
38+
state.iterate -= problem.f(state.iterate) / problem.df(state.iterate)
39+
return state
40+
end
41+
42+
# Testing the algorithm
43+
# ---------------------
44+
@testset "Babylonian square roots" begin
45+
f(x, a) = x^2 - a
46+
df(x, a) = 2x
47+
48+
a = 612.0
49+
problem = RootFindingProblem(x -> f(x, a), x -> df(x, a))
50+
algorithm1 = NewtonMethod(StopAfterIteration(8))
51+
solution1 = solve(problem, algorithm1)
52+
@test solution1.iterate sqrt(a)
53+
algorithm2 = NewtonMethod(StopAfterIteration(10))
54+
solution2 = solve(problem, algorithm2)
55+
@test solution2.iterate sqrt(a)
56+
@test abs(solution2.iterate - sqrt(a)) < abs(solution1.iterate - sqrt(a))
57+
end

0 commit comments

Comments
 (0)