Skip to content

Commit 1ad31d1

Browse files
committed
Implement stiff solver
1 parent cc2beea commit 1ad31d1

7 files changed

Lines changed: 150 additions & 18 deletions

File tree

src/coroutines.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
type Coroutine <: DiscreteProcess
1+
mutable struct Coroutine <: DiscreteProcess
22
bev :: BaseEvent
33
fsm :: FiniteStateMachine
44
target :: AbstractEvent

src/odes/QSS.jl

Lines changed: 75 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ struct QSS{T} <: Integrator
1111
function QSS{T}(model::Model, t::Float64, x₀::Vector{Float64}, p::Vector{Float64};
1212
order::Number=4, Δrel::Float64=1e-6, Δabs::Float64=1e-6) where T
1313
qss = new(UInt8(order), model, p, Vector{Taylor1{Float64}}(), Vector{Float64}(), Δrel, Δabs)
14-
t₀ = t + Taylor1(Float64, order)
14+
t₀ = t + Taylor1(Float64, order + 1)
1515
for q₀ in x₀
1616
push!(qss.t, t)
17-
push!(qss.q, q₀ + Taylor1(zeros(Float64, order+1)))
17+
push!(qss.q, q₀ + Taylor1(zeros(Float64, order + 1)))
1818
end
1919
for i in 1:order-1
2020
q = deepcopy(qss.q)
@@ -33,7 +33,7 @@ function Continuous(model::Model, env::Environment, x₀::Vector{Float64}, p::Ve
3333
end
3434

3535
function initial_values(qss::QSS, t::Float64)
36-
t₀ = t + Taylor1(Float64, qss.order+1)
36+
t₀ = t + Taylor1(Float64, qss.order + 1)
3737
x₀ = Vector{Taylor1{Float64}}()
3838
for (i, f) in enumerate(qss.model.f)
3939
push!(x₀, integrate(f(t₀, qss.q, qss.p), qss.q[i][1]))
@@ -45,10 +45,10 @@ function step(var::Variable, cont::Continuous, qss::QSS)
4545
t = now(environment(var))
4646
n = length(qss.model.f)
4747
i = var.id
48-
t₀ = t + Taylor1(Float64, qss.order+1)
48+
t₀ = t + Taylor1(Float64, qss.order + 1)
4949
x₀ = advance_time(var, t)
5050
update_quantized_state(qss, var, t)
51-
Δt = compute_next_time(var.x, max(qss.Δrel*x₀, qss.Δabs))
51+
Δt = compute_next_time(qss, var)
5252
reset(var)
5353
schedule(var, Δt)
5454
for j in filter(j->qss.model.deps[j,i], 1:n)
@@ -60,7 +60,7 @@ function step(var::Variable, cont::Continuous, qss::QSS)
6060
advance_time(qss, k, t)
6161
end
6262
dep.x = integrate(qss.model.f[j](t₀, qss.q, qss.p), x₀)
63-
Δt = recompute_next_time(qss, dep.x, qss.q[j], max(qss.Δrel*x₀, qss.Δabs))
63+
Δt = recompute_next_time(qss, dep)
6464
reset(dep)
6565
schedule(dep, Δt)
6666
end
@@ -79,19 +79,82 @@ function update_quantized_state(qss::QSS{non_stiff}, var::Variable, t::Float64)
7979
qss.t[i] = t
8080
end
8181

82-
function update_quantized_state(qss::QSS{stiff}, vars::Vector{Variable}, i::UInt, t::Float64)
82+
function update_quantized_state(qss::QSS{stiff}, var::Variable, t::Float64)
83+
i = var.id
84+
t₀ = t + Taylor1(Float64, qss.order + 1)
85+
x₀ = evaluate(var.x)
86+
Δq = max(qss.Δrel*x₀, qss.Δabs)
8387
for (j, istrue) in enumerate(qss.model.deps[i, :])
8488
istrue && advance_time(qss, j, t)
8589
end
86-
q₋ = deepcopy(qss.q)
90+
= deepcopy(qss.q)
91+
q̲[i] = Taylor1(zeros(order+1))+x₀-Δq
92+
= integrate(qss.model.f[i](t₀, q̲, qss.p), x₀)
93+
for k in 1:order-1
94+
q̲[i] =-Δq
95+
= integrate(qss.model.f[i](t₀, q̲, qss.p), x₀)
96+
end
97+
q̲[i] =-Δq
98+
q̲[i][end] = 0.0
99+
= deepcopy(qss.q)
100+
q̅[i] = Taylor1(zeros(order+1))+x₀+Δq
101+
= integrate(qss.model.f[i](t₀, q̅, qss.p), x₀)
102+
for k in 1:order-1
103+
q̅[i] =+Δq
104+
= integrate(qss.model.f[i](t₀, q̅, qss.p), x₀)
105+
end
106+
q̅[i] =+Δq
107+
q̅[i][end] = 0.0
108+
if x̲[end] * x̅[end] > 0.0
109+
if x̅[end] > 0.0
110+
var.x = deepcopy(x̅)
111+
q = deepcopy(q̅)
112+
else
113+
var.x = deepcopy(x̲)
114+
q = deepcopy(q̲)
115+
end
116+
else
117+
= brent(nth_derivative, x₀-Δq, x₀+Δq, qss, i; xtol=min(Δq/100, 1e-7))
118+
qss.q[i] =+ Taylor1(zeros(Float64, order + 1))
119+
var.x = integrate(qss.model.f[i](t₀, qss.q, qss.p), x₀)
120+
for k in 1:order-1
121+
qss.q[i] = deepcopy(var.x)
122+
qss.q[i][1] =
123+
var.x = integrate(qss.model.f[i](t₀, qss.q, qss.p), x₀)
124+
end
125+
qss.q[i][end] = 0.0
126+
end
87127
end
88128

89-
function compute_next_time(x::Taylor1, Δq::Float64)
90-
(abs(Δq/x[end]))^(1.0/x.order)
129+
function nth_derivative(q₀::Float64, qss::QSS{stiff}, i::UInt)
130+
t₀ = t + Taylor1(Float64, qss.order + 1)
131+
q = deepcopy(qss.q)
132+
q[i][2:end] = 0.0
133+
q[i][1] = q₀
134+
q[i] = integrate(qss.model.f[i](t₀, q, qss.p), q₀)
135+
for k in 1:order-1
136+
q[i] = integrate(qss.model.f[i](t₀, q, qss.p), q₀)
137+
end
138+
q[i][end]
139+
end
140+
141+
function compute_next_time(qss::QSS{non_stiff}, var::Variable)
142+
x₀ = evaluate(var.x)
143+
Δq = max(qss.Δrel*x₀, qss.Δabs)
144+
(abs(Δq/var.x[end]))^(1.0/qss.order)
91145
end
92146

93-
function recompute_next_time(::QSS{non_stiff}, x::Taylor1{Float64}, q::Taylor1{Float64}, Δq::Float64)
94-
p = (x-q).coeffs
147+
function compute_next_time(qss::QSS{stiff}, var::Variable)
148+
x₀ = evaluate(var.x)
149+
Δq = max(qss.Δrel*x₀, qss.Δabs)
150+
(abs(Δq/var.x[end]))^(1.0/qss.order)
151+
end
152+
153+
function recompute_next_time(qss::QSS{non_stiff}, var::Variable)
154+
i = var.id
155+
x₀ = evaluate(var.x)
156+
Δq = max(qss.Δrel*x₀, qss.Δabs)
157+
p = (var.x-qss.q[i]).coeffs
95158
p[1] -= Δq
96159
neg = roots(p)
97160
p[1] += 2Δq

src/odes/roots.jl

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,78 @@ function roots(coeff::Vector{Float64}) :: Vector{Complex{Float64}}
137137
mat[2:n-1, 1:n-2] = eye(Float64, n-2)
138138
mat[:, n-1] = - coeff[1:n-1] / coeff[n]
139139
res[1:n-1] = eigvals(mat)
140-
#res[1:n-1] = roots(Poly(coeff))
141140
end
142141
end
143142
end
144143
res
145144
end
145+
146+
function brent(f::Function, x0::Number, x1::Number, args...;
147+
xtol::AbstractFloat=1e-7, ytol=2eps(Float64),
148+
maxiter::Integer=50)
149+
EPS = eps(Float64)
150+
y0 = f(x0,args...)
151+
y1 = f(x1,args...)
152+
if abs(y0) < abs(y1)
153+
# Swap lower and upper bounds.
154+
x0, x1 = x1, x0
155+
y0, y1 = y1, y0
156+
end
157+
x2 = x0
158+
y2 = y0
159+
x3 = x2
160+
bisection = true
161+
for _ in 1:maxiter
162+
# x-tolerance.
163+
if abs(x1-x0) < xtol
164+
return x1
165+
end
166+
167+
# Use inverse quadratic interpolation if f(x0)!=f(x1)!=f(x2)
168+
# and linear interpolation (secant method) otherwise.
169+
if abs(y0-y2) > ytol && abs(y1-y2) > ytol
170+
x = x0*y1*y2/((y0-y1)*(y0-y2)) +
171+
x1*y0*y2/((y1-y0)*(y1-y2)) +
172+
x2*y0*y1/((y2-y0)*(y2-y1))
173+
else
174+
x = x1 - y1 * (x1-x0)/(y1-y0)
175+
end
176+
177+
# Use bisection method if satisfies the conditions.
178+
delta = abs(2EPS*abs(x1))
179+
min1 = abs(x-x1)
180+
min2 = abs(x1-x2)
181+
min3 = abs(x2-x3)
182+
if (x < (3x0+x1)/4 && x > x1) ||
183+
(bisection && min1 >= min2/2) ||
184+
(!bisection && min1 >= min3/2) ||
185+
(bisection && min2 < delta) ||
186+
(!bisection && min3 < delta)
187+
x = (x0+x1)/2
188+
bisection = true
189+
else
190+
bisection = false
191+
end
192+
193+
y = f(x,args...)
194+
# y-tolerance.
195+
if abs(y) < ytol
196+
return x
197+
end
198+
x3 = x2
199+
x2 = x1
200+
if sign(y0) != sign(y)
201+
x1 = x
202+
y1 = y
203+
else
204+
x0 = x
205+
y0 = y
206+
end
207+
if abs(y0) < abs(y1)
208+
# Swap lower and upper bounds.
209+
x0, x1 = x1, x0
210+
y0, y1 = y1, y0
211+
end
212+
end
213+
error("Max iteration exceeded")
214+
end

src/processes.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
type Process <: DiscreteProcess
1+
mutable struct Process <: DiscreteProcess
22
bev :: BaseEvent
33
task :: Task
44
target :: AbstractEvent

test/continuous.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ end
2727
sim = Simulation()
2828
cont = @continuous diffeq(sim, [0.0, 20.0], [2020.0, 0.0]; stiff=false, order=4, Δrel=1e-16, Δabs=1e-6)
2929
@process report(sim, cont)
30-
@time run(sim, 500)
30+
@time run(sim, 2000)
3131

3232
@model function bouncing_ball(t, x, p, dx)
3333
g = 9.81

test/events.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using SimJulia
22

3-
type TestException <: Exception end
3+
struct TestException <: Exception end
44

55
function test_callback_event(ev::Event)
66
println("Hi $ev has value $(value(ev))")

test/simulations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using SimJulia
22

3-
type TestException <: Exception end
3+
struct TestException <: Exception end
44

55
function test_callback(ev::AbstractEvent)
66
println("Hi I timed out at $(now(environment(ev)))")

0 commit comments

Comments
 (0)