Skip to content

Commit be5fc79

Browse files
authored
Preparation for FiniteDiff (#154)
* Add prep for gradient jacobian hessian
* Add prep for gradient jacobian hessian
* Full first order tests
* Add derivative
* Fix second order
1 parent 65c693a commit be5fc79

5 files changed

Lines changed: 156 additions & 58 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,10 @@ using DifferentiationInterface:
1010
NoPullbackExtras,
1111
NoPushforwardExtras
1212
using FiniteDiff:
13+
DerivativeCache,
14+
GradientCache,
15+
HessianCache,
16+
JacobianCache,
1317
finite_difference_derivative,
1418
finite_difference_gradient,
1519
finite_difference_gradient!,

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/allocating.jl

Lines changed: 128 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -21,86 +21,182 @@ end
2121

2222
## Derivative
2323

24-
DI.prepare_derivative(f, ::AnyAutoFiniteDiff, x) = NoDerivativeExtras()
24+
struct FiniteDiffAllocatingDerivativeExtras{C}
25+
cache::C
26+
end
27+
28+
function DI.prepare_derivative(f, backend::AnyAutoFiniteDiff, x)
29+
y = f(x)
30+
cache = if y isa Number
31+
nothing
32+
elseif y isa AbstractArray
33+
df = similar(y)
34+
cache = GradientCache(df, x, fdtype(backend), eltype(y), FUNCTION_NOT_INPLACE)
35+
end
36+
return FiniteDiffAllocatingDerivativeExtras(cache)
37+
end
38+
39+
### Scalar to scalar
2540

26-
function DI.derivative(f, backend::AnyAutoFiniteDiff, x, ::NoDerivativeExtras)
41+
function DI.derivative(
42+
f, backend::AnyAutoFiniteDiff, x, ::FiniteDiffAllocatingDerivativeExtras{Nothing}
43+
)
2744
return finite_difference_derivative(f, x, fdtype(backend))
2845
end
2946

30-
function DI.value_and_derivative(f, backend::AnyAutoFiniteDiff, x, ::NoDerivativeExtras)
47+
function DI.derivative!!(
48+
f,
49+
_der,
50+
backend::AnyAutoFiniteDiff,
51+
x,
52+
extras::FiniteDiffAllocatingDerivativeExtras{Nothing},
53+
)
54+
return DI.derivative(f, backend, x, extras)
55+
end
56+
57+
function DI.value_and_derivative(
58+
f, backend::AnyAutoFiniteDiff, x, ::FiniteDiffAllocatingDerivativeExtras{Nothing}
59+
)
3160
y = f(x)
3261
return y, finite_difference_derivative(f, x, fdtype(backend), eltype(y), y)
3362
end
3463

35-
function DI.derivative!!(f, der, backend::AnyAutoFiniteDiff, x, extras::NoDerivativeExtras)
36-
return DI.derivative(f, backend, x, extras)
64+
function DI.value_and_derivative!!(
65+
f,
66+
_der,
67+
backend::AnyAutoFiniteDiff,
68+
x,
69+
extras::FiniteDiffAllocatingDerivativeExtras{Nothing},
70+
)
71+
return DI.value_and_derivative(f, backend, x, extras)
72+
end
73+
74+
### Scalar to array
75+
76+
function DI.derivative(
77+
f, ::AnyAutoFiniteDiff, x, extras::FiniteDiffAllocatingDerivativeExtras{<:GradientCache}
78+
)
79+
return finite_difference_gradient(f, x, extras.cache)
80+
end
81+
82+
function DI.derivative!!(
83+
f,
84+
der,
85+
::AnyAutoFiniteDiff,
86+
x,
87+
extras::FiniteDiffAllocatingDerivativeExtras{<:GradientCache},
88+
)
89+
return finite_difference_gradient!(der, f, x, extras.cache)
90+
end
91+
92+
function DI.value_and_derivative(
93+
f, ::AnyAutoFiniteDiff, x, extras::FiniteDiffAllocatingDerivativeExtras{<:GradientCache}
94+
)
95+
y = f(x)
96+
return y, finite_difference_gradient(f, x, extras.cache)
3797
end
3898

3999
function DI.value_and_derivative!!(
40-
f, der, backend::AnyAutoFiniteDiff, x, extras::NoDerivativeExtras
100+
f,
101+
der,
102+
::AnyAutoFiniteDiff,
103+
x,
104+
extras::FiniteDiffAllocatingDerivativeExtras{<:GradientCache},
41105
)
42-
return DI.value_and_derivative(f, backend, x, extras)
106+
return f(x), finite_difference_gradient!(der, f, x, extras.cache)
43107
end
44108

45109
## Gradient
46110

47-
DI.prepare_gradient(f, ::AnyAutoFiniteDiff, x) = NoGradientExtras()
111+
struct FiniteDiffGradientExtras{C}
112+
cache::C
113+
end
48114

49-
function DI.gradient(f, backend::AnyAutoFiniteDiff, x::AbstractArray, ::NoGradientExtras)
50-
return finite_difference_gradient(f, x, fdtype(backend))
115+
function DI.prepare_gradient(f, backend::AnyAutoFiniteDiff, x)
116+
y = f(x)
117+
df = zero(y) .* x
118+
cache = GradientCache(df, x, fdtype(backend))
119+
return FiniteDiffGradientExtras(cache)
120+
end
121+
122+
function DI.gradient(
123+
f, ::AnyAutoFiniteDiff, x::AbstractArray, extras::FiniteDiffGradientExtras
124+
)
125+
return finite_difference_gradient(f, x, extras.cache)
51126
end
52127

53128
function DI.value_and_gradient(
54-
f, backend::AnyAutoFiniteDiff, x::AbstractArray, ::NoGradientExtras
129+
f, ::AnyAutoFiniteDiff, x::AbstractArray, extras::FiniteDiffGradientExtras
55130
)
56-
y = f(x)
57-
return y, finite_difference_gradient(f, x, fdtype(backend), typeof(y), y)
131+
return f(x), finite_difference_gradient(f, x, extras.cache)
58132
end
59133

60134
function DI.gradient!!(
61-
f, grad, backend::AnyAutoFiniteDiff, x::AbstractArray, ::NoGradientExtras
135+
f, grad, ::AnyAutoFiniteDiff, x::AbstractArray, extras::FiniteDiffGradientExtras
62136
)
63-
return finite_difference_gradient!(grad, f, x, fdtype(backend))
137+
return finite_difference_gradient!(grad, f, x, extras.cache)
64138
end
65139

66140
function DI.value_and_gradient!!(
67-
f, grad, backend::AnyAutoFiniteDiff, x::AbstractArray, ::NoGradientExtras
141+
f, grad, ::AnyAutoFiniteDiff, x::AbstractArray, extras::FiniteDiffGradientExtras
68142
)
69-
y = f(x)
70-
return y, finite_difference_gradient!(grad, f, x, fdtype(backend), typeof(y), y)
143+
return f(x), finite_difference_gradient!(grad, f, x, extras.cache)
71144
end
72145

73146
## Jacobian
74147

75-
DI.prepare_jacobian(f, ::AnyAutoFiniteDiff, x) = NoJacobianExtras()
148+
struct FiniteDiffAllocatingJacobianExtras{C}
149+
cache::C
150+
end
151+
152+
function DI.prepare_jacobian(f, backend::AnyAutoFiniteDiff, x)
153+
y = f(x)
154+
x1 = similar(x)
155+
fx = similar(y)
156+
fx1 = similar(y)
157+
cache = JacobianCache(x1, fx, fx1, fdjtype(backend))
158+
return FiniteDiffAllocatingJacobianExtras(cache)
159+
end
76160

77-
function DI.jacobian(f, backend::AnyAutoFiniteDiff, x, ::NoJacobianExtras)
78-
return finite_difference_jacobian(f, x, fdjtype(backend))
161+
function DI.jacobian(f, ::AnyAutoFiniteDiff, x, extras::FiniteDiffAllocatingJacobianExtras)
162+
return finite_difference_jacobian(f, x, extras.cache)
79163
end
80164

81-
function DI.value_and_jacobian(f, backend::AnyAutoFiniteDiff, x, ::NoJacobianExtras)
165+
function DI.value_and_jacobian(
166+
f, ::AnyAutoFiniteDiff, x, extras::FiniteDiffAllocatingJacobianExtras
167+
)
82168
y = f(x)
83-
return y, finite_difference_jacobian(f, x, fdjtype(backend), eltype(y), y)
169+
return y, finite_difference_jacobian(f, x, extras.cache, y)
84170
end
85171

86-
function DI.jacobian!!(f, jac, backend::AnyAutoFiniteDiff, x, extras::NoJacobianExtras)
87-
return DI.jacobian(f, backend, x, extras)
172+
function DI.jacobian!!(
173+
f, jac, ::AnyAutoFiniteDiff, x, extras::FiniteDiffAllocatingJacobianExtras
174+
)
175+
return finite_difference_jacobian(f, x, extras.cache; jac_prototype=jac)
88176
end
89177

90178
function DI.value_and_jacobian!!(
91-
f, jac, backend::AnyAutoFiniteDiff, x, extras::NoJacobianExtras
179+
f, jac, ::AnyAutoFiniteDiff, x, extras::FiniteDiffAllocatingJacobianExtras
92180
)
93-
return DI.value_and_jacobian(f, backend, x, extras)
181+
y = f(x)
182+
return y, finite_difference_jacobian(f, x, extras.cache, y; jac_prototype=jac)
94183
end
95184

96185
## Hessian
97186

98-
DI.prepare_hessian(f, ::AnyAutoFiniteDiff, x) = NoHessianExtras()
187+
struct FiniteDiffHessianExtras{C}
188+
cache::C
189+
end
190+
191+
function DI.prepare_hessian(f, backend::AnyAutoFiniteDiff, x)
192+
cache = HessianCache(x, fdhtype(backend))
193+
return FiniteDiffHessianExtras(cache)
194+
end
99195

100-
function DI.hessian(f, backend::AnyAutoFiniteDiff, x, ::NoHessianExtras)
101-
return finite_difference_hessian(f, x, fdhtype(backend))
196+
function DI.hessian(f, ::AnyAutoFiniteDiff, x, extras::FiniteDiffHessianExtras)
197+
return finite_difference_hessian(f, x, extras.cache)
102198
end
103199

104-
function DI.hessian!!(f, hess, backend::AnyAutoFiniteDiff, x, ::NoHessianExtras)
105-
return finite_difference_hessian!(hess, f, x, fdhtype(backend))
200+
function DI.hessian!!(f, hess, ::AnyAutoFiniteDiff, x, extras::FiniteDiffHessianExtras)
201+
return finite_difference_hessian!(hess, f, x, extras.cache)
106202
end

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/mutating.jl

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,22 @@ end
2525

2626
## Derivative
2727

28-
DI.prepare_derivative(f!, ::AnyAutoFiniteDiff, y, x) = NoDerivativeExtras()
28+
struct FiniteDiffMutatingDerivativeExtras{C}
29+
cache::C
30+
end
31+
32+
function DI.prepare_derivative(f!, ::AnyAutoFiniteDiff, y, x)
33+
cache = nothing
34+
return FiniteDiffMutatingDerivativeExtras(cache)
35+
end
2936

3037
function DI.value_and_derivative!!(
3138
f!,
3239
y::AbstractArray,
3340
der::AbstractArray,
3441
backend::AnyAutoFiniteDiff,
3542
x,
36-
::NoDerivativeExtras,
43+
::FiniteDiffMutatingDerivativeExtras,
3744
)
3845
f!(y, x)
3946
finite_difference_gradient!(der, f!, x, fdtype(backend), eltype(y), FUNCTION_INPLACE, y)
@@ -42,17 +49,27 @@ end
4249

4350
## Jacobian
4451

45-
DI.prepare_jacobian(f!, ::AnyAutoFiniteDiff, y, x) = NoJacobianExtras()
52+
struct FiniteDiffMutatingJacobianExtras{C}
53+
cache::C
54+
end
55+
56+
function DI.prepare_jacobian(f!, backend::AnyAutoFiniteDiff, y, x)
57+
x1 = similar(x)
58+
fx = similar(y)
59+
fx1 = similar(y)
60+
cache = JacobianCache(x1, fx, fx1, fdjtype(backend))
61+
return FiniteDiffMutatingJacobianExtras(cache)
62+
end
4663

4764
function DI.value_and_jacobian!!(
4865
f!,
4966
y::AbstractArray,
5067
jac::AbstractMatrix,
51-
backend::AnyAutoFiniteDiff,
68+
::AnyAutoFiniteDiff,
5269
x,
53-
::NoJacobianExtras,
70+
extras::FiniteDiffMutatingJacobianExtras,
5471
)
72+
finite_difference_jacobian!(jac, f!, x, extras.cache)
5573
f!(y, x)
56-
finite_difference_jacobian!(jac, f!, x, fdjtype(backend), eltype(y), y)
5774
return y, jac
5875
end

DifferentiationInterface/ext/DifferentiationInterfaceTapirExt/allocating.jl

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -34,25 +34,6 @@ function DI.pullback!!(
3434
return DI.value_and_pullback!!(f, dx, backend, x, dy, extras)[2]
3535
end
3636

37-
#=
38-
# First try
39-
40-
function DI.value_and_pullback_split(f, ::AutoTapir, x, extras::TapirAllocatingPullbackExtras)
41-
tf = zero_tangent(f)
42-
tx = zero_tangent(x)
43-
out, pb!! = extras.rrule(CoDual(f, tf), CoDual(x, tx))
44-
y = copy(primal(out))
45-
function pullbackfunc(dy)
46-
dy_righttype = convert(tangent_type(typeof(y)), copy(dy))
47-
ty = increment!!(tangent(out), dy_righttype)
48-
res = pb!!(ty, tf, tx)
49-
extras.rrule(CoDual(f, tf), CoDual(x, tx))
50-
return last(res)
51-
end
52-
return y, pullbackfunc
53-
end
54-
=#
55-
5637
function DI.value_and_pullback_split(
5738
f, backend::AutoTapir, x, extras::TapirAllocatingPullbackExtras
5839
)

DifferentiationInterface/test/second_order.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ second_order_mixed_backends = [
1111
# forward over reverse
1212
SecondOrder(AutoForwardDiff(), AutoZygote()),
1313
# reverse over forward
14-
SecondOrder(AutoZygote(), AutoFiniteDiff()),
14+
SecondOrder(AutoEnzyme(Enzyme.Reverse), AutoForwardDiff()),
1515
# reverse over reverse
1616
SecondOrder(AutoReverseDiff(), AutoZygote()),
1717
]

0 commit comments

Comments
 (0)