| 1 | +using RecursiveArrayTools, Zygote, ForwardDiff, Test |
| 2 | +using SciMLBase |
| 3 | + |
# Scalar loss through `VectorOfArray`: stack five scaled copies of `x` into a
# dense matrix and return its squared Frobenius norm.
function loss(x)
    columns = map(i -> x .* i, 1:5)
    return sum(abs2, Array(VectorOfArray(columns)))
end
| 7 | + |
# Same loss as `loss`, but routed through `DiffEqArray` (with dummy time
# points 1:5) to exercise its `Array` conversion adjoint.
function loss2(x)
    columns = map(i -> x .* i, 1:5)
    return sum(abs2, Array(DiffEqArray(columns, 1:5)))
end
| 11 | + |
# Sum every element of a VectorOfArray via scalar two-index `getindex`,
# exercising the `y[i, j]` adjoint (y[i, j] addresses element i of array j).
function loss3(x)
    y = VectorOfArray([x .* i for i in 1:5])
    # Type-stable accumulator: `zero(eltype(x))` keeps `tmp` the same type as
    # the elements (e.g. a ForwardDiff Dual) instead of promoting from the
    # Float64 literal 0.0 on the first iteration.
    tmp = zero(eltype(x))
    for i in 1:5, j in 1:5
        tmp += y[i, j]
    end
    return tmp
end
| 21 | + |
# Same elementwise-indexing sum as `loss3`, but through `DiffEqArray` (with
# dummy time points 1:5) to cover its scalar `getindex` adjoint.
function loss4(x)
    y = DiffEqArray([x .* i for i in 1:5], 1:5)
    # Type-stable accumulator (see `loss3`): avoid promoting from a Float64
    # literal so the loop stays generic over the element type of `x`.
    tmp = zero(eltype(x))
    for i in 1:5, j in 1:5
        tmp += y[i, j]
    end
    return tmp
end
| 31 | + |
# Squared norm of an `ArrayPartition` built from five scaled copies of `x`,
# covering the ArrayPartition -> Array conversion adjoint.
function loss5(x)
    pieces = map(i -> x .* i, 1:5)
    return sum(abs2, Array(ArrayPartition(pieces...)))
end
| 35 | + |
# Round-trip an `ArrayPartition` through `ODEProblem` construction and make
# sure the stored `u0` field remains differentiable.
function loss6(x)
    partitioned = ArrayPartition([x .* i for i in 1:5]...)
    problem = ODEProblem((u, p, t) -> u, partitioned, (0, 1))
    return sum(abs2, Array(problem.u0))
end
| 41 | + |
# Broadcast a scalar subtraction directly over a `VectorOfArray` (no `Array`
# conversion) and reduce with the squared norm.
function loss7(x)
    stacked = VectorOfArray([x .* i for i in 1:5])
    shifted = stacked .- 1
    return sum(abs2, shifted)
end
| 46 | + |
# use a bunch of broadcasts to test all the adjoints
# NOTE(review): each line below targets a distinct broadcast adjoint defined
# for VectorOfArray; the exact sequence is deliberate — do not reorder or
# algebraically simplify it.
function loss8(x)
    _x = VectorOfArray([x .* i for i in 1:5])
    res = copy(_x)  # copy adjoint
    res = res .+ _x  # VoA .+ VoA
    res = res .+ 1  # VoA .+ scalar
    res = res .* _x  # VoA .* VoA
    res = res .* 2.0  # VoA .* scalar
    res = res .* res  # VoA .* itself
    res = res ./ 2.0  # VoA ./ scalar
    res = res ./ _x  # VoA ./ VoA
    res = 3.0 .- res  # scalar .- VoA
    res = .-res  # unary minus broadcast
    res = identity.(Base.literal_pow.(^, res, Val(2)))  # literal_pow + identity broadcasts
    res = tanh.(res)  # unary function broadcast
    res = res .+ im .* res  # promote to complex elements
    res = conj.(res) .+ real.(res) .+ imag.(res) .+ abs2.(res)  # complex-valued unary broadcasts
    return sum(abs2, res)
end
| 66 | + |
# Returns a `VectorOfArray` (not a scalar); paired with
# `ForwardDiff.derivative` below to check non-scalar derivative outputs.
function loss9(x)
    return VectorOfArray([x .* collect((3i):(3i + 3)) for i in 1:5])
end
| 70 | + |
# Sum over a rectangular `view` of a VectorOfArray to cover the ranged-view
# adjoint.
function loss10(x)
    voa = VectorOfArray([i * x for i in 1:5])
    window = @view voa[2:4, 3:5]
    return sum(window)
end
| 75 | + |
# Sum over a whole-array `view` (all-Colon indices) to cover that adjoint
# path separately from the ranged view in `loss10`.
function loss11(x)
    voa = VectorOfArray([i * x for i in 1:5])
    return sum(@view voa[:, :])
end
| 80 | + |
# Shared test input for all gradient comparisons below.
x = float.(6:10)
loss(x)  # smoke-test the primal once before differentiating
# Reverse mode (Zygote) must agree exactly with forward mode (ForwardDiff)
# for every adjoint path exercised by loss .. loss11.
@test Zygote.gradient(loss, x)[1] == ForwardDiff.gradient(loss, x)
@test Zygote.gradient(loss2, x)[1] == ForwardDiff.gradient(loss2, x)
@test Zygote.gradient(loss3, x)[1] == ForwardDiff.gradient(loss3, x)
@test Zygote.gradient(loss4, x)[1] == ForwardDiff.gradient(loss4, x)
@test Zygote.gradient(loss5, x)[1] == ForwardDiff.gradient(loss5, x)
@test Zygote.gradient(loss6, x)[1] == ForwardDiff.gradient(loss6, x)
@test Zygote.gradient(loss7, x)[1] == ForwardDiff.gradient(loss7, x)
@test Zygote.gradient(loss8, x)[1] == ForwardDiff.gradient(loss8, x)
# loss9 returns a VectorOfArray, so check the scalar derivative structurally.
@test ForwardDiff.derivative(loss9, 0.0) ==
      VectorOfArray([collect((3i):(3i + 3)) for i in 1:5])
@test Zygote.gradient(loss10, x)[1] == ForwardDiff.gradient(loss10, x)
@test Zygote.gradient(loss11, x)[1] == ForwardDiff.gradient(loss11, x)
| 95 | + |
# Differentiating with respect to a VectorOfArray input should produce a
# structural gradient of the same container type, not a plain Array.
voa = RecursiveArrayTools.VectorOfArray(fill(rand(3), 3))
voa_gs, = Zygote.gradient(voa) do x
    sum(sum.(x.u))
end
@test voa_gs isa RecursiveArrayTools.VectorOfArray