You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl
+37-42Lines changed: 37 additions & 42 deletions
Original file line number
Diff line number
Diff line change
@@ -11,10 +11,13 @@ using Flux:
11
11
ConvTranspose,
12
12
Dense,
13
13
GRU,
14
+
GRUCell,
14
15
LSTM,
16
+
LSTMCell,
15
17
Maxout,
16
18
MeanPool,
17
19
RNN,
20
+
RNNCell,
18
21
SamePad,
19
22
Scale,
20
23
SkipConnection,
@@ -24,6 +27,7 @@ using Flux:
24
27
relu
25
28
using Functors:@functor, fmapstructure_with_path, fleaves
26
29
using LinearAlgebra
30
+
using Statistics: mean
27
31
using Random: AbstractRNG, default_rng
28
32
29
33
#=
@@ -43,31 +47,23 @@ end
43
47
44
48
function DIT.flux_isapprox(a, b; atol, rtol)
45
49
isapprox_results =fmapstructure_with_path(a, b) do kp, x, y
46
-
if:statein kp # ignore RNN and LSTM state
50
+
if x isa AbstractArray{<:Number}
51
+
returnisapprox(x, y; atol, rtol)
52
+
else# ignore non-arrays
47
53
returntrue
48
-
else
49
-
if x isa AbstractArray{<:Number}
50
-
returnisapprox(x, y; atol, rtol)
51
-
else# ignore non-arrays
52
-
returntrue
53
-
end
54
54
end
55
55
end
56
56
returnall(fleaves(isapprox_results))
57
57
end
58
58
59
-
functionsquare_loss(model, x)
60
-
Flux.reset!(model)
61
-
returnsum(abs2, model(x))
62
-
end
59
+
square_loss(model, x) =mean(abs2, model(x))
63
60
64
-
functionsquare_loss_iterated(model, x)
65
-
Flux.reset!(model)
66
-
y =copy(x)
67
-
for _ in1:3
68
-
y =model(y)
61
+
functionsquare_loss_iterated(cell, x)
62
+
y, st =cell(x) # uses default initial state
63
+
for _ in1:2
64
+
y, st =cell(x, st)
69
65
end
70
-
returnsum(abs2, y)
66
+
returnmean(abs2, y)
71
67
end
72
68
73
69
struct SimpleDense{W,B,F}
@@ -132,37 +128,33 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng())
0 commit comments