@@ -3,7 +3,14 @@ module DifferentiationInterfaceFiniteDifferencesExt
33using ADTypes: AutoFiniteDifferences
44import DifferentiationInterface as DI
55using DifferentiationInterface:
6- NoGradientPrep, NoJacobianPrep, NoPullbackPrep, NoPushforwardPrep, Tangents
6+ Context,
7+ NoGradientPrep,
8+ NoJacobianPrep,
9+ NoPullbackPrep,
10+ NoPushforwardPrep,
11+ Tangents,
12+ unwrap,
13+ with_contexts
714using FiniteDifferences: FiniteDifferences, grad, jacobian, jvp, j′vp
815using LinearAlgebra: dot
916
@@ -12,85 +19,158 @@ DI.inplace_support(::AutoFiniteDifferences) = DI.InPlaceNotSupported()
1219
1320# # Pushforward
1421
15- function DI. prepare_pushforward (f, :: AutoFiniteDifferences , x, tx:: Tangents )
22+ function DI. prepare_pushforward (
23+ f, :: AutoFiniteDifferences , x, tx:: Tangents , contexts:: Vararg{Context,C}
24+ ) where {C}
1625 return NoPushforwardPrep ()
1726end
1827
1928function DI. pushforward (
20- f, :: NoPushforwardPrep , backend:: AutoFiniteDifferences , x, tx:: Tangents
21- )
29+ f,
30+ :: NoPushforwardPrep ,
31+ backend:: AutoFiniteDifferences ,
32+ x,
33+ tx:: Tangents ,
34+ contexts:: Vararg{Context,C} ,
35+ ) where {C}
36+ fc = with_contexts (f, contexts... )
2237 ty = map (tx) do dx
23- jvp (backend. fdm, f , (x, dx))
38+ jvp (backend. fdm, fc , (x, dx))
2439 end
2540 return ty
2641end
2742
2843function DI. value_and_pushforward (
29- f, prep:: NoPushforwardPrep , backend:: AutoFiniteDifferences , x, tx:: Tangents
30- )
31- return f (x), DI. pushforward (f, prep, backend, x, tx)
44+ f,
45+ prep:: NoPushforwardPrep ,
46+ backend:: AutoFiniteDifferences ,
47+ x,
48+ tx:: Tangents ,
49+ contexts:: Vararg{Context,C} ,
50+ ) where {C}
51+ return f (x, map (unwrap, contexts)... ),
52+ DI. pushforward (f, prep, backend, x, tx, contexts... )
3253end
3354
3455# # Pullback
3556
36- DI. prepare_pullback (f, :: AutoFiniteDifferences , x, ty:: Tangents ) = NoPullbackPrep ()
57+ function DI. prepare_pullback (
58+ f, :: AutoFiniteDifferences , x, ty:: Tangents , contexts:: Vararg{Context,C}
59+ ) where {C}
60+ return NoPullbackPrep ()
61+ end
3762
38- function DI. pullback (f, :: NoPullbackPrep , backend:: AutoFiniteDifferences , x, ty:: Tangents )
63+ function DI. pullback (
64+ f,
65+ :: NoPullbackPrep ,
66+ backend:: AutoFiniteDifferences ,
67+ x,
68+ ty:: Tangents ,
69+ contexts:: Vararg{Context,C} ,
70+ ) where {C}
71+ fc = with_contexts (f, contexts... )
3972 tx = map (ty) do dy
40- only (j′vp (backend. fdm, f , dy, x))
73+ only (j′vp (backend. fdm, fc , dy, x))
4174 end
4275 return tx
4376end
4477
4578function DI. value_and_pullback (
46- f, prep:: NoPullbackPrep , backend:: AutoFiniteDifferences , x, ty:: Tangents
47- )
48- return f (x), DI. pullback (f, prep, backend, x, ty)
79+ f,
80+ prep:: NoPullbackPrep ,
81+ backend:: AutoFiniteDifferences ,
82+ x,
83+ ty:: Tangents ,
84+ contexts:: Vararg{Context,C} ,
85+ ) where {C}
86+ return f (x, map (unwrap, contexts)... ), DI. pullback (f, prep, backend, x, ty, contexts... )
4987end
5088
5189# # Gradient
5290
53- DI. prepare_gradient (f, :: AutoFiniteDifferences , x) = NoGradientPrep ()
91+ function DI. prepare_gradient (
92+ f, :: AutoFiniteDifferences , x, contexts:: Vararg{Context,C}
93+ ) where {C}
94+ return NoGradientPrep ()
95+ end
5496
55- function DI. gradient (f, :: NoGradientPrep , backend:: AutoFiniteDifferences , x)
56- return only (grad (backend. fdm, f, x))
97+ function DI. gradient (
98+ f, :: NoGradientPrep , backend:: AutoFiniteDifferences , x, contexts:: Vararg{Context,C}
99+ ) where {C}
100+ fc = with_contexts (f, contexts... )
101+ return only (grad (backend. fdm, fc, x))
57102end
58103
59- function DI. value_and_gradient (f, prep:: NoGradientPrep , backend:: AutoFiniteDifferences , x)
60- return f (x), DI. gradient (f, prep, backend, x)
104+ function DI. value_and_gradient (
105+ f, prep:: NoGradientPrep , backend:: AutoFiniteDifferences , x, contexts:: Vararg{Context,C}
106+ ) where {C}
107+ return f (x, map (unwrap, contexts)... ), DI. gradient (f, prep, backend, x, contexts... )
61108end
62109
63- function DI. gradient! (f, grad, prep:: NoGradientPrep , backend:: AutoFiniteDifferences , x)
64- return copyto! (grad, DI. gradient (f, prep, backend, x))
110+ function DI. gradient! (
111+ f,
112+ grad,
113+ prep:: NoGradientPrep ,
114+ backend:: AutoFiniteDifferences ,
115+ x,
116+ contexts:: Vararg{Context,C} ,
117+ ) where {C}
118+ return copyto! (grad, DI. gradient (f, prep, backend, x, contexts... ))
65119end
66120
67121function DI. value_and_gradient! (
68- f, grad, prep:: NoGradientPrep , backend:: AutoFiniteDifferences , x
69- )
70- y, new_grad = DI. value_and_gradient (f, prep, backend, x)
122+ f,
123+ grad,
124+ prep:: NoGradientPrep ,
125+ backend:: AutoFiniteDifferences ,
126+ x,
127+ contexts:: Vararg{Context,C} ,
128+ ) where {C}
129+ y, new_grad = DI. value_and_gradient (f, prep, backend, x, contexts... )
71130 return y, copyto! (grad, new_grad)
72131end
73132
74133# # Jacobian
75134
76- DI. prepare_jacobian (f, :: AutoFiniteDifferences , x) = NoJacobianPrep ()
135+ function DI. prepare_jacobian (
136+ f, :: AutoFiniteDifferences , x, contexts:: Vararg{Context,C}
137+ ) where {C}
138+ return NoJacobianPrep ()
139+ end
77140
78- function DI. jacobian (f, :: NoJacobianPrep , backend:: AutoFiniteDifferences , x)
79- return only (jacobian (backend. fdm, f, x))
141+ function DI. jacobian (
142+ f, :: NoJacobianPrep , backend:: AutoFiniteDifferences , x, contexts:: Vararg{Context,C}
143+ ) where {C}
144+ fc = with_contexts (f, contexts... )
145+ return only (jacobian (backend. fdm, fc, x))
80146end
81147
82- function DI. value_and_jacobian (f, prep:: NoJacobianPrep , backend:: AutoFiniteDifferences , x)
83- return f (x), DI. jacobian (f, prep, backend, x)
148+ function DI. value_and_jacobian (
149+ f, prep:: NoJacobianPrep , backend:: AutoFiniteDifferences , x, contexts:: Vararg{Context,C}
150+ ) where {C}
151+ return f (x, map (unwrap, contexts)... ), DI. jacobian (f, prep, backend, x, contexts... )
84152end
85153
86- function DI. jacobian! (f, jac, prep:: NoJacobianPrep , backend:: AutoFiniteDifferences , x)
87- return copyto! (jac, DI. jacobian (f, prep, backend, x))
154+ function DI. jacobian! (
155+ f,
156+ jac,
157+ prep:: NoJacobianPrep ,
158+ backend:: AutoFiniteDifferences ,
159+ x,
160+ contexts:: Vararg{Context,C} ,
161+ ) where {C}
162+ return copyto! (jac, DI. jacobian (f, prep, backend, x, contexts... ))
88163end
89164
90165function DI. value_and_jacobian! (
91- f, jac, prep:: NoJacobianPrep , backend:: AutoFiniteDifferences , x
92- )
93- y, new_jac = DI. value_and_jacobian (f, prep, backend, x)
166+ f,
167+ jac,
168+ prep:: NoJacobianPrep ,
169+ backend:: AutoFiniteDifferences ,
170+ x,
171+ contexts:: Vararg{Context,C} ,
172+ ) where {C}
173+ y, new_jac = DI. value_and_jacobian (f, prep, backend, x, contexts... )
94174 return y, copyto! (jac, new_jac)
95175end
96176
0 commit comments