-
-
Notifications
You must be signed in to change notification settings - Fork 74
Expand file tree
/
Copy pathsymbol_indexing.jl
More file actions
74 lines (62 loc) · 2.41 KB
/
symbol_indexing.jl
File metadata and controls
74 lines (62 loc) · 2.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
using RecursiveArrayTools, ModelingToolkit, OrdinaryDiffEq, SymbolicIndexingInterface, Test
using Zygote
using ModelingToolkit: t_nounits as t, D_nounits as D
include("../testutils.jl")
# First-order lag written with an explicit algebraic variable `RHS`. The tests
# below expect `x` to be the only solved variable after `@mtkcompile`, with
# `RHS` still reachable through symbolic indexing.
@variables x(t)
@parameters τ
@variables RHS(t)
@mtkcompile fol_separate = System([RHS ~ (1 - x) / τ, D(x) ~ RHS], t)
prob = ODEProblem(fol_separate, [x => 0.0, τ => 3.0], (0.0, 10.0))
sol = solve(prob, Tsit5())

# Rebuild a DiffEqArray from the first 10 saved steps, passing `sol` as the last
# argument; the tests check it answers symbolic queries the same way as `sol`.
sol_new = DiffEqArray(sol.u[1:10],
    sol.t[1:10],
    sol.prob.p,
    sol)
@test sol_new[RHS] ≈ (1 .- sol_new[x]) ./ 3.0
@test sol_new[t] ≈ sol_new.t
@test sol_new[t, 1:5] ≈ sol_new.t[1:5]
@test getp(sol, τ)(sol) == getp(sol_new, τ)(sol_new) == 3.0

# Array `isequal` checks lengths before comparing elements. `all(isequal.(a, b))`
# would broadcast a length-1 expectation such as `[x]` and accept extra entries.
@test isequal(variable_symbols(sol), variable_symbols(sol_new))
@test isequal(variable_symbols(sol), [x])
@test isequal(all_variable_symbols(sol), all_variable_symbols(sol_new))
@test isequal(all_variable_symbols(sol), [x, RHS])
@test isequal(all_symbols(sol), all_symbols(sol_new))
@test all(sym -> any(isequal(sym), all_symbols(sol)),
    [x, RHS, τ, t, Initial(x), Initial(RHS)])

# Selector-style indexing must agree with explicit symbol lists.
@test sol[solvedvariables] == sol[[x]]
@test sol_new[solvedvariables] == sol_new[[x]]
@test sol[allvariables] == sol[[x, RHS]]
@test sol_new[allvariables] == sol_new[[x, RHS]]

# Parameters are not time series; indexing the solution with one must throw.
@test_throws Exception sol[τ]
@test_throws Exception sol_new[τ]
# Reverse-mode AD through symbolic indexing: the gradient of sum(sol[x]) with
# respect to the solution is expected to be one in every entry. The gradient is
# computed inside the testset so an AD failure is recorded as a test error
# instead of aborting the rest of the file.
@testset "Symbolic Indexing ADjoint" begin
    gs, = Zygote.gradient(sol) do s
        sum(s[fol_separate.x])
    end
    @test all(all.(isone, gs))
end
# Tables interface: the DiffEqArray should present a `timestamp` column followed
# by one column per state, with the data matching time and `x` side by side.
let expected_columns = [:timestamp, Symbol("x(t)")],
    expected_table = hcat(sol_new[t], sol_new[x])

    test_tables_interface(sol_new, expected_columns, expected_table)
end
# Two components: a Lotka–Volterra predator–prey model, used to check the
# Tables interface of a DiffEqArray with more than one state column.
@variables y(t)
@parameters α β γ δ
# Prey: αx − βxy. Predator: δxy − γy (the predator death term is linear in y).
@mtkcompile lv = System([D(x) ~ α * x - β * x * y,
        D(y) ~ δ * x * y - γ * y], t)
prob = ODEProblem(lv,
    [x => 1.0, y => 1.0, α => 1.5, β => 1.0, γ => 3.0, δ => 1.0],
    (0.0, 10.0))
sol = solve(prob, Tsit5())
ts = 0:0.5:10
# Interpolating the solution at a vector of times is expected to give a
# DiffEqArray. This is a real test so a failure is reported, not silently skipped.
sol_ts = sol(ts)
@test sol_ts isa DiffEqArray
# `Array(sol_ts)` is states × times; transpose it to put one row per timestamp.
test_tables_interface(sol_ts, [:timestamp, Symbol("x(t)"), Symbol("y(t)")],
    hcat(ts, Array(sol_ts)'))
# Array variables: a 3-element array state `x` and a scalar state `y`, both given
# defaults at declaration so the problem is built with no extra initial values.
using LinearAlgebra
sts = @variables x(t)[1:3]=[1, 2, 3.0] y(t)=1.0
# NOTE(review): `p` does not appear in `eqs`; presumably declared only to
# exercise an array parameter in the system — confirm intent.
ps = @parameters p[1:3] = [1, 2, 3]
eqs = [collect(D.(x) .~ x)
       D(y) ~ norm(collect(x)) * y - x[1]]
# `System` matches the constructor used for `fol_separate` and `lv` above;
# `ODESystem` is the deprecated spelling.
@mtkcompile sys = System(eqs, t, sts, ps)
prob = ODEProblem(sys, [], (0.0, 1.0))
sol = solve(prob, Tsit5())
# Indexing with an array of symbolic expressions must agree with indexing each
# expression on its own.
@test sol[x .+ [y, 2y, 3y]] ≈ vcat.(getindex.((sol,), [x[1] + y, x[2] + 2y, x[3] + 3y])...)
@test sol[x, :] ≈ sol[x]