Skip to content

Commit 50e2a52

Browse files
authored
Right preparation for second order closures (#122)
* Right preparation for second order closures * Fix
1 parent 21a96ea commit 50e2a52

10 files changed

Lines changed: 172 additions & 54 deletions

File tree

ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,14 @@ import DifferentiationInterface as DI
66
using FastDifferentiation:
77
derivative,
88
hessian,
9+
hessian_times_v,
910
jacobian,
1011
jacobian_times_v,
1112
jacobian_transpose_v,
1213
make_function,
13-
make_variables
14+
make_variables,
15+
sparse_hessian,
16+
sparse_jacobian
1417
using LinearAlgebra: dot
1518
using FastDifferentiation.RuntimeGeneratedFunctions: RuntimeGeneratedFunction
1619

@@ -24,6 +27,9 @@ DI.supports_mutation(::AnyAutoFastDifferentiation) = DI.MutationNotSupported()
2427
myvec(x::Number) = [x]
2528
myvec(x::AbstractArray) = vec(x)
2629

30+
issparse(::AutoFastDifferentiation) = false
31+
issparse(::AutoSparseFastDifferentiation) = true
32+
2733
include("allocating.jl")
2834

2935
end

ext/DifferentiationInterfaceFastDifferentiationExt/allocating.jl

Lines changed: 103 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,43 @@ end
6464
6565
=#
6666

67+
## Derivative
68+
69+
function DI.prepare_derivative(f, ::AnyAutoFastDifferentiation, x)
70+
x_var = only(make_variables(:x))
71+
y_var = f(x_var)
72+
73+
x_vec_var = [x_var]
74+
y_vec_var = y_var isa Number ? [y_var] : vec(y_var)
75+
der_vec_var = derivative(y_vec_var, x_var)
76+
der_exe = make_function(der_vec_var, x_vec_var; in_place=false)
77+
return der_exe
78+
end
79+
80+
function DI.value_and_derivative(f, ::AnyAutoFastDifferentiation, x, der_exe)
81+
y = f(x)
82+
der_vec = der_exe([x])
83+
if y isa Number
84+
return y, only(der_vec)
85+
else
86+
return y, reshape(der_vec, size(y))
87+
end
88+
end
89+
90+
function DI.value_and_derivative!!(f, der, backend::AnyAutoFastDifferentiation, x, der_exe)
91+
return DI.value_and_derivative(f, backend, x, der_exe)
92+
end
93+
6794
## Jacobian
6895

69-
function DI.prepare_jacobian(f, ::AnyAutoFastDifferentiation, x)
96+
function DI.prepare_jacobian(f, backend::AnyAutoFastDifferentiation, x)
7097
x_vec_var = make_variables(:x, size(x)...)
7198
y_vec_var = f(x_vec_var)
72-
jac_var = jacobian(vec(y_vec_var), vec(x_vec_var))
99+
if issparse(backend)
100+
jac_var = sparse_jacobian(vec(y_vec_var), vec(x_vec_var))
101+
else
102+
jac_var = jacobian(vec(y_vec_var), vec(x_vec_var))
103+
end
73104
jac_exe = make_function(jac_var, vec(x_vec_var); in_place=false)
74105
return jac_exe
75106
end
@@ -80,24 +111,85 @@ function DI.jacobian(
80111
return jac_exe(vec(x))
81112
end
82113

83-
function DI.value_and_jacobian(f, backend, x, extras)
84-
return f(x), DI.jacobian(f, backend, x, extras)
114+
function DI.value_and_jacobian(f, backend::AnyAutoFastDifferentiation, x, jac_exe)
115+
return f(x), DI.jacobian(f, backend, x, jac_exe)
116+
end
117+
118+
function DI.jacobian!!(f, jac, backend::AnyAutoFastDifferentiation, x, jac_exe)
119+
return DI.jacobian(f, backend, x, jac_exe)
120+
end
121+
122+
function DI.value_and_jacobian!!(f, jac, backend::AnyAutoFastDifferentiation, x, jac_exe)
123+
return DI.value_and_jacobian(f, backend, x, jac_exe)
124+
end
125+
126+
## Second derivative
127+
128+
function DI.prepare_second_derivative(f, ::AnyAutoFastDifferentiation, x)
129+
x_var = only(make_variables(:x))
130+
y_var = f(x_var)
131+
132+
x_vec_var = [x_var]
133+
y_vec_var = y_var isa Number ? [y_var] : vec(y_var)
134+
der2_vec_var = derivative(y_vec_var, x_var, x_var)
135+
der2_exe = make_function(der2_vec_var, x_vec_var; in_place=false)
136+
return der2_exe
137+
end
138+
139+
function DI.second_derivative(f, ::AnyAutoFastDifferentiation, x, der2_exe)
140+
y = f(x)
141+
der2_vec = der2_exe([x])
142+
if y isa Number
143+
return only(der2_vec)
144+
else
145+
return reshape(der2_vec, size(y))
146+
end
147+
end
148+
149+
function DI.second_derivative!!(f, der2, backend::AnyAutoFastDifferentiation, x, der2_exe)
150+
return DI.second_derivative(f, backend, x, der2_exe)
151+
end
152+
153+
## HVP
154+
155+
function DI.prepare_hvp(f, ::AnyAutoFastDifferentiation, x)
156+
x_var = if x isa Number
157+
only(make_variables(:x))
158+
else
159+
make_variables(:x, size(x)...)
160+
end
161+
y_var = f(x_var)
162+
163+
x_vec_var = x_var isa Number ? [x_var] : vec(x_var)
164+
hv_vec_var, v_vec_var = hessian_times_v(y_var, x_vec_var)
165+
hvp_exe = make_function(hv_vec_var, [x_vec_var; v_vec_var]; in_place=false)
166+
return hvp_exe
85167
end
86168

87-
function DI.jacobian!!(f, backend::AnyAutoFastDifferentiation, x, extras)
88-
return DI.jacobian(f, backend, x, extras)
169+
function DI.hvp(f, ::AnyAutoFastDifferentiation, x, v, hvp_exe::RuntimeGeneratedFunction)
170+
v_vec = vcat(myvec(x), myvec(v))
171+
hv_vec = hvp_exe(v_vec)
172+
if x isa Number
173+
return only(hv_vec)
174+
else
175+
return reshape(hv_vec, size(x))
176+
end
89177
end
90178

91-
function DI.value_and_jacobian!!(f, backend::AnyAutoFastDifferentiation, x, extras)
92-
return DI.value_and_jacobian(f, backend, x, extras)
179+
function DI.hvp!!(f, p, backend::AnyAutoFastDifferentiation, x, v, hvp_exe)
180+
return DI.hvp(f, backend, x, v, hvp_exe)
93181
end
94182

95183
## Hessian
96184

97-
function DI.prepare_hessian(f, ::AnyAutoFastDifferentiation, x)
185+
function DI.prepare_hessian(f, backend::AnyAutoFastDifferentiation, x)
98186
x_vec_var = make_variables(:x, size(x)...)
99187
y_vec_var = f(x_vec_var)
100-
hess_var = hessian(y_vec_var, vec(x_vec_var))
188+
if issparse(backend)
189+
hess_var = sparse_hessian(y_vec_var, vec(x_vec_var))
190+
else
191+
hess_var = hessian(y_vec_var, vec(x_vec_var))
192+
end
101193
hess_exe = make_function(hess_var, vec(x_vec_var); in_place=false)
102194
return hess_exe
103195
end
@@ -108,11 +200,6 @@ function DI.hessian(
108200
return hess_exe(vec(x))
109201
end
110202

111-
function DI.hessian(f, backend::AnyAutoFastDifferentiation, x, extras::Nothing)
112-
hess_exe = prepare_hessian(f, backend, x)
203+
function DI.hessian!!(f, hess, backend::AnyAutoFastDifferentiation, x, hess_exe)
113204
return DI.hessian(f, backend, x, hess_exe)
114205
end
115-
116-
function DI.hessian!!(f, backend::AnyAutoFastDifferentiation, x, extras)
117-
return DI.hessian(f, backend, x, extras)
118-
end

ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,8 @@ function DI.hessian(f, ::AnyAutoZygote, x, extras::Nothing)
6161
return hessian(f, x)
6262
end
6363

64+
function DI.hessian!!(f, hess, backend::AnyAutoZygote, x, extras::Nothing)
65+
return DI.hessian(f, backend, x, extras)
66+
end
67+
6468
end

lib/DifferentiationInterfaceTest/src/utils/printing.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ pretty(::AutoTapir) = "Tapir"
1616
pretty(::AutoTracker) = "Tracker"
1717
pretty(::AutoZygote) = "Zygote"
1818

19+
pretty(::AutoSparseFastDifferentiation) = "FastDifferentiation sparse"
1920
pretty(::AutoSparseFiniteDiff) = "FiniteDiff sparse"
2021
pretty(::AutoSparseForwardDiff) = "ForwardDiff sparse"
2122
pretty(::AutoSparsePolyesterForwardDiff) = "PolyesterForwardDiff sparse"

src/hessian.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
hessian(f, backend, x, [extras]) -> hess
55
"""
66
function hessian(f, backend::AbstractADType, x, extras=prepare_hessian(f, backend, x))
7-
new_backend = SecondOrder(backend, backend)
7+
new_backend = SecondOrder(backend)
88
new_extras = prepare_hessian(f, new_backend, x)
99
return hessian(f, new_backend, x, new_extras)
1010
end
@@ -24,7 +24,7 @@ end
2424
function hessian!!(
2525
f, hess, backend::AbstractADType, x, extras=prepare_hessian(f, backend, x)
2626
)
27-
new_backend = SecondOrder(backend, backend)
27+
new_backend = SecondOrder(backend)
2828
new_extras = prepare_hessian(f, new_backend, x)
2929
return hessian!!(f, hess, new_backend, x, new_extras)
3030
end

src/hvp.jl

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ By order of preference:
1414
hvp(f, backend, x, v, [extras]) -> p
1515
"""
1616
function hvp(f, backend::AbstractADType, x, v, extras=prepare_hvp(f, backend, x))
17-
new_backend = SecondOrder(backend, backend)
17+
new_backend = SecondOrder(backend)
1818
new_extras = prepare_hvp(f, new_backend, x)
1919
return hvp(f, new_backend, x, v, new_extras)
2020
end
@@ -32,26 +32,32 @@ end
3232

3333
function hvp_aux(f, backend, x, v, extras, ::ForwardOverReverse)
3434
# JVP of the gradient
35-
inner_extras = prepare_gradient(extras, f, inner(backend), x)
36-
gradient_closure(z) = gradient(f, inner(backend), z, inner_extras)
35+
function gradient_closure(z)
36+
inner_extras = prepare_gradient(extras, f, inner(backend), z)
37+
return gradient(f, inner(backend), z, inner_extras)
38+
end
3739
outer_extras = prepare_pushforward(extras, gradient_closure, outer(backend), x)
3840
p = pushforward(gradient_closure, outer(backend), x, v, outer_extras)
3941
return p
4042
end
4143

4244
function hvp_aux(f, backend, x, v, extras, ::ReverseOverForward)
4345
# gradient of the JVP
44-
inner_extras = prepare_pushforward(extras, f, inner(backend), x)
45-
jvp_closure(z) = pushforward(f, inner(backend), z, v, inner_extras)
46+
function jvp_closure(z)
47+
inner_extras = prepare_pushforward(extras, f, inner(backend), z)
48+
return pushforward(f, inner(backend), z, v, inner_extras)
49+
end
4650
outer_extras = prepare_gradient(extras, jvp_closure, outer(backend), x)
4751
p = gradient(jvp_closure, outer(backend), x, outer_extras)
4852
return p
4953
end
5054

5155
function hvp_aux(f, backend, x, v, extras, ::ReverseOverReverse)
5256
# VJP of the gradient
53-
inner_extras = prepare_gradient(extras, f, inner(backend), x)
54-
gradient_closure(z) = gradient(f, inner(backend), z, inner_extras)
57+
function gradient_closure(z)
58+
inner_extras = prepare_gradient(extras, f, inner(backend), z)
59+
return gradient(f, inner(backend), z, inner_extras)
60+
end
5561
outer_extras = prepare_pullback(extras, gradient_closure, outer(backend), x)
5662
p = pullback(gradient_closure, outer(backend), x, v, outer_extras)
5763
return p
@@ -60,8 +66,10 @@ end
6066
function hvp_aux(f, backend, x, v, extras, ::ForwardOverForward)
6167
# JVPs of JVPs in theory
6268
# also pushforward of gradient in practice
63-
inner_extras = prepare_gradient(extras, f, inner(backend), x)
64-
gradient_closure(z) = gradient(f, inner(backend), z, nothing) # TODO: fix
69+
function gradient_closure(z)
70+
inner_extras = prepare_gradient(extras, f, inner(backend), z)
71+
return gradient(f, inner(backend), z, inner_extras)
72+
end
6573
outer_extras = prepare_pushforward(extras, gradient_closure, outer(backend), x)
6674
p = pushforward(gradient_closure, outer(backend), x, v, outer_extras)
6775
return p
@@ -71,7 +79,7 @@ end
7179
hvp!!(f, p, backend, x, v, [extras]) -> p
7280
"""
7381
function hvp!!(f, p, backend::AbstractADType, x, v, extras=prepare_hvp(f, backend, x))
74-
new_backend = SecondOrder(backend, backend)
82+
new_backend = SecondOrder(backend)
7583
new_extras = prepare_hvp(f, new_backend, x)
7684
return hvp!!(f, p, new_backend, x, v, new_extras)
7785
end
@@ -87,32 +95,40 @@ function hvp!!(f, p, backend::SecondOrder, x, v, extras=prepare_hvp(f, backend,
8795
end
8896

8997
function hvp_aux!!(f, p, backend, x, v, extras, ::ForwardOverReverse)
90-
inner_extras = prepare_gradient(extras, f, inner(backend), x)
91-
gradient_closure(z) = gradient(f, inner(backend), z, inner_extras)
98+
function gradient_closure(z)
99+
inner_extras = prepare_gradient(extras, f, inner(backend), z)
100+
return gradient(f, inner(backend), z, inner_extras)
101+
end
92102
outer_extras = prepare_pushforward(extras, gradient_closure, outer(backend), x)
93103
p = pushforward!!(gradient_closure, p, outer(backend), x, v, outer_extras)
94104
return p
95105
end
96106

97107
function hvp_aux!!(f, p, backend, x, v, extras, ::ReverseOverForward)
98-
inner_extras = prepare_pushforward(extras, f, inner(backend), x)
99-
jvp_closure(z) = pushforward(f, inner(backend), z, v, inner_extras)
108+
function jvp_closure(z)
109+
inner_extras = prepare_pushforward(extras, f, inner(backend), z)
110+
return pushforward(f, inner(backend), z, v, inner_extras)
111+
end
100112
outer_extras = prepare_gradient(extras, jvp_closure, outer(backend), x)
101113
p = gradient!!(jvp_closure, p, outer(backend), x, outer_extras)
102114
return p
103115
end
104116

105117
function hvp_aux!!(f, p, backend, x, v, extras, ::ReverseOverReverse)
106-
inner_extras = prepare_gradient(extras, f, inner(backend), x)
107-
gradient_closure(z) = gradient(f, inner(backend), z, inner_extras)
118+
function gradient_closure(z)
119+
inner_extras = prepare_gradient(extras, f, inner(backend), z)
120+
return gradient(f, inner(backend), z, inner_extras)
121+
end
108122
outer_extras = prepare_pullback(extras, gradient_closure, outer(backend), x)
109123
p = pullback!!(gradient_closure, p, outer(backend), x, v, outer_extras)
110124
return p
111125
end
112126

113127
function hvp_aux!!(f, p, backend, x, v, extras, ::ForwardOverForward)
114-
inner_extras = prepare_gradient(extras, f, inner(backend), x)
115-
gradient_closure(z) = gradient(f, inner(backend), z, nothing) # TODO: fix
128+
function gradient_closure(z)
129+
inner_extras = prepare_gradient(extras, f, inner(backend), z)
130+
return gradient(f, inner(backend), z, inner_extras)
131+
end
116132
outer_extras = prepare_pushforward(extras, gradient_closure, outer(backend), x)
117133
p = pushforward!!(gradient_closure, p, outer(backend), x, v, outer_extras)
118134
return p

src/second_derivative.jl

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,20 @@
44
second_derivative(f, backend, x, [extras]) -> der2
55
"""
66
function second_derivative(
7-
f, backend::AbstractADType, x::Number, extras=prepare_second_derivative(f, backend, x)
7+
f, backend::AbstractADType, x, extras=prepare_second_derivative(f, backend, x)
88
)
99
new_backend = SecondOrder(backend, backend)
1010
new_extras = prepare_second_derivative(f, new_backend, x)
1111
return second_derivative(f, new_backend, x, new_extras)
1212
end
1313

1414
function second_derivative(
15-
f, backend::SecondOrder, x::Number, extras=prepare_second_derivative(f, backend, x)
15+
f, backend::SecondOrder, x, extras=prepare_second_derivative(f, backend, x)
1616
)
17-
inner_extras = prepare_derivative(extras, f, inner(backend), x)
18-
derivative_closure(z) = derivative(f, inner(backend), z, inner_extras)
17+
function derivative_closure(z)
18+
inner_extras = prepare_derivative(extras, f, inner(backend), z)
19+
return derivative(f, inner(backend), z, inner_extras)
20+
end
1921
outer_extras = prepare_derivative(extras, derivative_closure, outer(backend), x)
2022
der2 = derivative(derivative_closure, outer(backend), x, outer_extras)
2123
return der2
@@ -25,26 +27,20 @@ end
2527
second_derivative!!(f, der2, backend, x, [extras]) -> der2
2628
"""
2729
function second_derivative!!(
28-
f,
29-
der2,
30-
backend::AbstractADType,
31-
x::Number,
32-
extras=prepare_second_derivative(f, backend, x),
30+
f, der2, backend::AbstractADType, x, extras=prepare_second_derivative(f, backend, x)
3331
)
3432
new_backend = SecondOrder(backend, backend)
3533
new_extras = prepare_second_derivative(f, new_backend, x)
3634
return second_derivative!!(f, der2, new_backend, x, new_extras)
3735
end
3836

3937
function second_derivative!!(
40-
f,
41-
der2,
42-
backend::SecondOrder,
43-
x::Number,
44-
extras=prepare_second_derivative(f, backend, x),
38+
f, der2, backend::SecondOrder, x, extras=prepare_second_derivative(f, backend, x)
4539
)
46-
inner_extras = prepare_derivative(extras, f, inner(backend), x)
47-
derivative_closure(z) = derivative(f, inner(backend), z, inner_extras)
40+
function derivative_closure(z)
41+
inner_extras = prepare_derivative(extras, f, inner(backend), z)
42+
return derivative(f, inner(backend), z, inner_extras)
43+
end
4844
outer_extras = prepare_derivative(extras, derivative_closure, outer(backend), x)
4945
der2 = derivative!!(derivative_closure, der2, outer(backend), x, outer_extras)
5046
return der2

src/second_order.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ struct SecondOrder{AD1<:AbstractADType,AD2<:AbstractADType} <: AbstractADType
1414
inner::AD2
1515
end
1616

17+
SecondOrder(backend::AbstractADType) = SecondOrder(backend, backend)
18+
1719
inner(backend::SecondOrder) = backend.inner
1820
outer(backend::SecondOrder) = backend.outer
1921

0 commit comments

Comments
 (0)