Skip to content

Commit 14fb6ce

Browse files
authored
Allow mutation for FastDifferentiation (#144)
* Allow mutation for FastDifferentiation * De-toggle tests
1 parent 069554a commit 14fb6ce

3 files changed

Lines changed: 158 additions & 25 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ using DifferentiationInterface:
1010
HVPExtras,
1111
JacobianExtras,
1212
PullbackExtras,
13-
PushforwardExtras
13+
PushforwardExtras,
14+
SecondDerivativeExtras
1415
using FastDifferentiation:
1516
derivative,
1617
hessian,
@@ -30,7 +31,6 @@ const AnyAutoFastDifferentiation = Union{
3031
}
3132

3233
DI.mode(::AnyAutoFastDifferentiation) = ADTypes.AbstractSymbolicDifferentiationMode
33-
DI.supports_mutation(::AnyAutoFastDifferentiation) = DI.MutationNotSupported()
3434

3535
myvec(x::Number) = [x]
3636
myvec(x::AbstractArray) = vec(x)
@@ -39,5 +39,6 @@ issparse(::AutoFastDifferentiation) = false
3939
issparse(::AutoSparseFastDifferentiation) = true
4040

4141
include("allocating.jl")
42+
include("mutating.jl")
4243

4344
end

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/allocating.jl

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
## Pushforward
22

3-
struct FastDifferentiationAllocatingPushforwardExtras{E} <: PushforwardExtras
3+
struct FastDifferentiationAllocatingPushforwardExtras{Y,E} <: PushforwardExtras
4+
y_prototype::Y
45
jvp_exe::E
56
end
67

78
function DI.prepare_pushforward(f, ::AnyAutoFastDifferentiation, x)
9+
y_prototype = f(x)
810
x_var = if x isa Number
911
only(make_variables(:x))
1012
else
@@ -16,7 +18,7 @@ function DI.prepare_pushforward(f, ::AnyAutoFastDifferentiation, x)
1618
y_vec_var = y_var isa Number ? [y_var] : vec(y_var)
1719
jv_vec_var, v_vec_var = jacobian_times_v(y_vec_var, x_vec_var)
1820
jvp_exe = make_function(jv_vec_var, [x_vec_var; v_vec_var]; in_place=false)
19-
return FastDifferentiationAllocatingPushforwardExtras(jvp_exe)
21+
return FastDifferentiationAllocatingPushforwardExtras(y_prototype, jvp_exe)
2022
end
2123

2224
function DI.value_and_pushforward(
@@ -36,6 +38,22 @@ function DI.value_and_pushforward(
3638
end
3739
end
3840

41+
function DI.pushforward(
42+
f,
43+
::AnyAutoFastDifferentiation,
44+
x,
45+
dx,
46+
extras::FastDifferentiationAllocatingPushforwardExtras,
47+
)
48+
v_vec = vcat(myvec(x), myvec(dx))
49+
jv_vec = extras.jvp_exe(v_vec)
50+
if extras.y_prototype isa Number
51+
return only(jv_vec)
52+
else
53+
return reshape(jv_vec, size(extras.y_prototype))
54+
end
55+
end
56+
3957
## Pullback
4058

4159
# TODO: this only fails for scalar -> matrix, not sure why
@@ -82,19 +100,21 @@ end
82100

83101
## Derivative
84102

85-
struct FastDifferentiationAllocatingDerivativeExtras{E} <: DerivativeExtras
103+
struct FastDifferentiationAllocatingDerivativeExtras{Y,E} <: DerivativeExtras
104+
y_prototype::Y
86105
der_exe::E
87106
end
88107

89108
function DI.prepare_derivative(f, ::AnyAutoFastDifferentiation, x)
109+
y_prototype = f(x)
90110
x_var = only(make_variables(:x))
91111
y_var = f(x_var)
92112

93113
x_vec_var = [x_var]
94114
y_vec_var = y_var isa Number ? [y_var] : vec(y_var)
95115
der_vec_var = derivative(y_vec_var, x_var)
96116
der_exe = make_function(der_vec_var, x_vec_var; in_place=false)
97-
return FastDifferentiationAllocatingDerivativeExtras(der_exe)
117+
return FastDifferentiationAllocatingDerivativeExtras(y_prototype, der_exe)
98118
end
99119

100120
function DI.value_and_derivative(
@@ -128,7 +148,12 @@ function DI.derivative(
128148
x,
129149
extras::FastDifferentiationAllocatingDerivativeExtras,
130150
)
131-
return DI.value_and_derivative(f, backend, x, extras)[2]
151+
der_vec = extras.der_exe([x])
152+
if extras.y_prototype isa Number
153+
return only(der_vec)
154+
else
155+
return reshape(der_vec, size(extras.y_prototype))
156+
end
132157
end
133158

134159
function DI.derivative!!(
@@ -143,20 +168,25 @@ end
143168

144169
## Jacobian
145170

146-
struct FastDifferentiationAllocatingJacobianExtras{E} <: JacobianExtras
171+
struct FastDifferentiationAllocatingJacobianExtras{Y,E} <: JacobianExtras
172+
y_prototype::Y
147173
jac_exe::E
148174
end
149175

150176
function DI.prepare_jacobian(f, backend::AnyAutoFastDifferentiation, x)
151-
x_vec_var = make_variables(:x, size(x)...)
152-
y_vec_var = f(x_vec_var)
177+
y_prototype = f(x)
178+
x_var = make_variables(:x, size(x)...)
179+
y_var = f(x_var)
180+
181+
x_vec_var = vec(x_var)
182+
y_vec_var = vec(y_var)
153183
if issparse(backend)
154-
jac_var = sparse_jacobian(vec(y_vec_var), vec(x_vec_var))
184+
jac_var = sparse_jacobian(y_vec_var, x_vec_var)
155185
else
156-
jac_var = jacobian(vec(y_vec_var), vec(x_vec_var))
186+
jac_var = jacobian(y_vec_var, x_vec_var)
157187
end
158-
jac_exe = make_function(jac_var, vec(x_vec_var); in_place=false)
159-
return FastDifferentiationAllocatingJacobianExtras(jac_exe)
188+
jac_exe = make_function(jac_var, x_vec_var; in_place=false)
189+
return FastDifferentiationAllocatingJacobianExtras(y_prototype, jac_exe)
160190
end
161191

162192
function DI.jacobian(
@@ -199,19 +229,21 @@ end
199229

200230
## Second derivative
201231

202-
struct FastDifferentiationAllocatingSecondDerivativeExtras{E} <: DerivativeExtras
232+
struct FastDifferentiationAllocatingSecondDerivativeExtras{Y,E} <: SecondDerivativeExtras
233+
y_prototype::Y
203234
der2_exe::E
204235
end
205236

206237
function DI.prepare_second_derivative(f, ::AnyAutoFastDifferentiation, x)
238+
y_prototype = f(x)
207239
x_var = only(make_variables(:x))
208240
y_var = f(x_var)
209241

210242
x_vec_var = [x_var]
211243
y_vec_var = y_var isa Number ? [y_var] : vec(y_var)
212244
der2_vec_var = derivative(y_vec_var, x_var, x_var)
213245
der2_exe = make_function(der2_vec_var, x_vec_var; in_place=false)
214-
return FastDifferentiationAllocatingSecondDerivativeExtras(der2_exe)
246+
return FastDifferentiationAllocatingSecondDerivativeExtras(y_prototype, der2_exe)
215247
end
216248

217249
function DI.second_derivative(
@@ -220,12 +252,11 @@ function DI.second_derivative(
220252
x,
221253
extras::FastDifferentiationAllocatingSecondDerivativeExtras,
222254
)
223-
y = f(x)
224255
der2_vec = extras.der2_exe([x])
225-
if y isa Number
256+
if extras.y_prototype isa Number
226257
return only(der2_vec)
227258
else
228-
return reshape(der2_vec, size(y))
259+
return reshape(der2_vec, size(extras.y_prototype))
229260
end
230261
end
231262

@@ -253,7 +284,7 @@ function DI.prepare_hvp(f, ::AnyAutoFastDifferentiation, x, v)
253284
end
254285
y_var = f(x_var)
255286

256-
x_vec_var = x_var isa Number ? [x_var] : vec(x_var)
287+
x_vec_var = vec(x_var)
257288
hv_vec_var, v_vec_var = hessian_times_v(y_var, x_vec_var)
258289
hvp_exe = make_function(hv_vec_var, [x_vec_var; v_vec_var]; in_place=false)
259290
return FastDifferentiationHVPExtras(hvp_exe)
@@ -262,11 +293,7 @@ end
262293
function DI.hvp(f, ::AnyAutoFastDifferentiation, x, v, extras::FastDifferentiationHVPExtras)
263294
v_vec = vcat(myvec(x), myvec(v))
264295
hv_vec = extras.hvp_exe(v_vec)
265-
if x isa Number
266-
return only(hv_vec)
267-
else
268-
return reshape(hv_vec, size(x))
269-
end
296+
return reshape(hv_vec, size(x))
270297
end
271298

272299
function DI.hvp!!(
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
## Pushforward
2+
3+
struct FastDifferentiationMutatingPushforwardExtras{E} <: PushforwardExtras
4+
jvp_exe::E
5+
end
6+
7+
function DI.prepare_pushforward(f!, ::AnyAutoFastDifferentiation, y, x)
8+
x_var = if x isa Number
9+
only(make_variables(:x))
10+
else
11+
make_variables(:x, size(x)...)
12+
end
13+
y_var = make_variables(:y, size(y)...)
14+
f!(y_var, x_var)
15+
16+
x_vec_var = x_var isa Number ? [x_var] : vec(x_var)
17+
y_vec_var = vec(y_var)
18+
jv_vec_var, v_vec_var = jacobian_times_v(y_vec_var, x_vec_var)
19+
jvp_exe = make_function(jv_vec_var, [x_vec_var; v_vec_var]; in_place=false)
20+
return FastDifferentiationMutatingPushforwardExtras(jvp_exe)
21+
end
22+
23+
function DI.value_and_pushforward!!(
24+
f!,
25+
y,
26+
_dy,
27+
::AnyAutoFastDifferentiation,
28+
x,
29+
dx,
30+
extras::FastDifferentiationMutatingPushforwardExtras,
31+
)
32+
f!(y, x)
33+
v_vec = vcat(myvec(x), myvec(dx))
34+
jv_vec = extras.jvp_exe(v_vec)
35+
if y isa Number
36+
return y, only(jv_vec)
37+
else
38+
return y, reshape(jv_vec, size(y))
39+
end
40+
end
41+
42+
## Derivative
43+
44+
struct FastDifferentiationMutatingDerivativeExtras{E} <: DerivativeExtras
45+
der_exe::E
46+
end
47+
48+
function DI.prepare_derivative(f!, ::AnyAutoFastDifferentiation, y, x)
49+
x_var = only(make_variables(:x))
50+
y_var = make_variables(:y, size(y)...)
51+
f!(y_var, x_var)
52+
53+
x_vec_var = [x_var]
54+
y_vec_var = vec(y_var)
55+
der_vec_var = derivative(y_vec_var, x_var)
56+
der_exe = make_function(der_vec_var, x_vec_var; in_place=false)
57+
return FastDifferentiationMutatingDerivativeExtras(der_exe)
58+
end
59+
60+
function DI.value_and_derivative!!(
61+
f!,
62+
y,
63+
_der,
64+
::AnyAutoFastDifferentiation,
65+
x,
66+
extras::FastDifferentiationMutatingDerivativeExtras,
67+
)
68+
f!(y, x)
69+
der_vec = extras.der_exe([x])
70+
return y, reshape(der_vec, size(y))
71+
end
72+
73+
## Jacobian
74+
75+
struct FastDifferentiationMutatingJacobianExtras{E} <: JacobianExtras
76+
jac_exe::E
77+
end
78+
79+
function DI.prepare_jacobian(f!, backend::AnyAutoFastDifferentiation, y, x)
80+
x_var = make_variables(:x, size(x)...)
81+
y_var = make_variables(:y, size(y)...)
82+
f!(y_var, x_var)
83+
84+
x_vec_var = vec(x_var)
85+
y_vec_var = vec(y_var)
86+
if issparse(backend)
87+
jac_var = sparse_jacobian(y_vec_var, x_vec_var)
88+
else
89+
jac_var = jacobian(y_vec_var, x_vec_var)
90+
end
91+
jac_exe = make_function(jac_var, x_vec_var; in_place=false)
92+
return FastDifferentiationMutatingJacobianExtras(jac_exe)
93+
end
94+
95+
function DI.value_and_jacobian!!(
96+
f!,
97+
y,
98+
_jac,
99+
backend::AnyAutoFastDifferentiation,
100+
x,
101+
extras::FastDifferentiationMutatingJacobianExtras,
102+
)
103+
f!(y, x)
104+
return y, extras.jac_exe(vec(x))
105+
end

0 commit comments

Comments
 (0)