Skip to content

Commit 9df2763

Browse files
adapt to Flux v0.16 (#661)
1 parent 62ec930 commit 9df2763

3 files changed

Lines changed: 41 additions & 44 deletions

File tree

DifferentiationInterface/test/Down/Flux/test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ LOGGING = get(ENV, "CI", "false") == "false"
1515
test_differentiation(
1616
[
1717
AutoZygote(),
18-
# AutoEnzyme() # TODO: fix
18+
# AutoEnzyme(), # TODO a few scenarios fail
1919
],
2020
DIT.flux_scenarios(Random.MersenneTwister(0));
2121
isapprox=DIT.flux_isapprox,

DifferentiationInterfaceTest/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1515
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
1616
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1717
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
18+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1819
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1920

2021
[weakdeps]
@@ -46,7 +47,7 @@ DifferentiationInterface = "0.6.0"
4647
DocStringExtensions = "0.8,0.9"
4748
ExplicitImports = "1.10.1"
4849
FiniteDifferences = "0.12"
49-
Flux = "0.13,0.14"
50+
Flux = "0.16"
5051
ForwardDiff = "0.10.36"
5152
Functors = "0.4, 0.5"
5253
JET = "0.4 - 0.8, 0.9"
@@ -61,6 +62,7 @@ SparseArrays = "<0.0.1,1"
6162
SparseConnectivityTracer = "0.5.0,0.6"
6263
SparseMatrixColorings = "0.4.9"
6364
StaticArrays = "1.9"
65+
Statistics = "1"
6466
Test = "<0.0.1,1"
6567
Zygote = "0.6"
6668
julia = "1.10"

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl

Lines changed: 37 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@ using Flux:
1111
ConvTranspose,
1212
Dense,
1313
GRU,
14+
GRUCell,
1415
LSTM,
16+
LSTMCell,
1517
Maxout,
1618
MeanPool,
1719
RNN,
20+
RNNCell,
1821
SamePad,
1922
Scale,
2023
SkipConnection,
@@ -24,6 +27,7 @@ using Flux:
2427
relu
2528
using Functors: @functor, fmapstructure_with_path, fleaves
2629
using LinearAlgebra
30+
using Statistics: mean
2731
using Random: AbstractRNG, default_rng
2832

2933
#=
@@ -43,31 +47,23 @@ end
4347

4448
function DIT.flux_isapprox(a, b; atol, rtol)
4549
isapprox_results = fmapstructure_with_path(a, b) do kp, x, y
46-
if :state in kp # ignore RNN and LSTM state
50+
if x isa AbstractArray{<:Number}
51+
return isapprox(x, y; atol, rtol)
52+
else # ignore non-arrays
4753
return true
48-
else
49-
if x isa AbstractArray{<:Number}
50-
return isapprox(x, y; atol, rtol)
51-
else # ignore non-arrays
52-
return true
53-
end
5454
end
5555
end
5656
return all(fleaves(isapprox_results))
5757
end
5858

59-
function square_loss(model, x)
60-
Flux.reset!(model)
61-
return sum(abs2, model(x))
62-
end
59+
square_loss(model, x) = mean(abs2, model(x))
6360

64-
function square_loss_iterated(model, x)
65-
Flux.reset!(model)
66-
y = copy(x)
67-
for _ in 1:3
68-
y = model(y)
61+
function square_loss_iterated(cell, x)
62+
y, st = cell(x) # uses default initial state
63+
for _ in 1:2
64+
y, st = cell(x, st)
6965
end
70-
return sum(abs2, y)
66+
return mean(abs2, y)
7167
end
7268

7369
struct SimpleDense{W,B,F}
@@ -132,37 +128,33 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng())
132128
Maxout(() -> Dense(5 => 4, tanh; init), 3),
133129
randn(rng, Float32, 5, 1)
134130
),
135-
(
136-
RNN(3 => 2; init),
137-
randn(rng, Float32, 3, 2)
138-
),
139-
(
140-
Chain(RNN(3 => 4; init), RNN(4 => 3; init)),
141-
randn(rng, Float32, 3, 2)
131+
(
132+
SkipConnection(Dense(2 => 2; init), vcat),
133+
randn(rng, Float32, 2, 3)
142134
),
143135
(
144-
LSTM(3 => 5; init),
145-
randn(rng, Float32, 3, 2)
136+
Bilinear((2, 2) => 3; init),
137+
randn(rng, Float32, 2, 1)
146138
),
147139
(
148-
Chain(LSTM(3 => 5; init), LSTM(5 => 3; init)),
149-
randn(rng, Float32, 3, 2)
140+
ConvTranspose((3, 3), 3 => 2; stride=2, init),
141+
rand(rng, Float32, 5, 5, 3, 1)
150142
),
151143
(
152-
SkipConnection(Dense(2 => 2; init), vcat),
153-
randn(rng, Float32, 2, 3)
144+
RNN(3 => 4; init_kernel=init, init_recurrent_kernel=init),
145+
randn(rng, Float32, 3, 2, 1)
154146
),
155147
(
156-
Bilinear((2, 2) => 3; init),
157-
randn(rng, Float32, 2, 1)
148+
LSTM(3 => 4; init_kernel=init, init_recurrent_kernel=init),
149+
randn(rng, Float32, 3, 2, 1)
158150
),
159151
(
160-
GRU(3 => 5; init),
161-
randn(rng, Float32, 3, 10)
152+
GRU(3 => 4; init_kernel=init, init_recurrent_kernel=init),
153+
randn(rng, Float32, 3, 2, 1)
162154
),
163155
(
164-
ConvTranspose((3, 3), 3 => 2; stride=2, init),
165-
rand(rng, Float32, 5, 5, 3, 1)
156+
Chain(LSTM(3 => 4), RNN(4 => 5), Dense(5 => 2)),
157+
randn(rng, Float32, 3, 2, 1)
166158
),
167159
#! format: on
168160
]
@@ -176,29 +168,32 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng())
176168
push!(scens, scen)
177169
end
178170

179-
# Recurrence
171+
# Recurrent Cells
180172

181173
recurrent_models_and_xs = [
182174
#! format: off
183175
(
184-
RNN(3 => 3; init),
176+
RNNCell(3 => 3; init_kernel=init, init_recurrent_kernel=init),
177+
randn(rng, Float32, 3, 2)
178+
),
179+
(
180+
LSTMCell(3 => 3; init_kernel=init, init_recurrent_kernel=init),
185181
randn(rng, Float32, 3, 2)
186182
),
187183
(
188-
LSTM(3 => 3; init),
184+
GRUCell(3 => 3; init_kernel=init, init_recurrent_kernel=init),
189185
randn(rng, Float32, 3, 2)
190186
),
191187
#! format: on
192188
]
193189

194190
for (model, x) in recurrent_models_and_xs
195191
Flux.trainmode!(model)
196-
g = gradient_finite_differences(square_loss, model, x)
192+
g = gradient_finite_differences(square_loss_iterated, model, x)
197193
scen = DIT.Scenario{:gradient,:out}(
198194
square_loss_iterated, model; contexts=(DI.Constant(x),), res1=g
199195
)
200-
# TODO: figure out why these tests are broken
201-
# push!(scens, scen)
196+
push!(scens, scen)
202197
end
203198

204199
return scens

0 commit comments

Comments
 (0)