11# # Pushforward
22
3- struct FiniteDiffOneArgPushforwardPrep{C,R,A} <: DI.PushforwardPrep
3+ struct FiniteDiffOneArgPushforwardPrep{C,R,A,D } <: DI.PushforwardPrep
44 cache:: C
55 relstep:: R
66 absstep:: A
7+ dir:: D
78end
89
910function DI. prepare_pushforward (
@@ -26,7 +27,8 @@ function DI.prepare_pushforward(
2627 else
2728 backend. relstep
2829 end
29- return FiniteDiffOneArgPushforwardPrep (cache, relstep, absstep)
30+ dir = backend. dir
31+ return FiniteDiffOneArgPushforwardPrep (cache, relstep, absstep, dir)
3032end
3133
3234function DI. pushforward (
@@ -37,11 +39,11 @@ function DI.pushforward(
3739 tx:: NTuple ,
3840 contexts:: Vararg{DI.Context,C} ,
3941) where {C}
40- (; relstep, absstep) = prep
42+ (; relstep, absstep, dir ) = prep
4143 step (t:: Number , dx) = f (x .+ t .* dx, map (DI. unwrap, contexts)... )
4244 ty = map (tx) do dx
4345 finite_difference_derivative (
44- Base. Fix2 (step, dx), zero (eltype (x)), fdtype (backend); relstep, absstep
46+ Base. Fix2 (step, dx), zero (eltype (x)), fdtype (backend); relstep, absstep, dir
4547 )
4648 end
4749 return ty
@@ -55,7 +57,7 @@ function DI.value_and_pushforward(
5557 tx:: NTuple ,
5658 contexts:: Vararg{DI.Context,C} ,
5759) where {C}
58- (; relstep, absstep) = prep
60+ (; relstep, absstep, dir ) = prep
5961 step (t:: Number , dx) = f (x .+ t .* dx, map (DI. unwrap, contexts)... )
6062 y = f (x, map (DI. unwrap, contexts)... )
6163 ty = map (tx) do dx
@@ -67,6 +69,7 @@ function DI.value_and_pushforward(
6769 y;
6870 relstep,
6971 absstep,
72+ dir,
7073 )
7174 end
7275 return y, ty
@@ -80,10 +83,10 @@ function DI.pushforward(
8083 tx:: NTuple ,
8184 contexts:: Vararg{DI.Context,C} ,
8285) where {C}
83- (; relstep, absstep) = prep
86+ (; relstep, absstep, dir ) = prep
8487 fc = DI. with_contexts (f, contexts... )
8588 ty = map (tx) do dx
86- finite_difference_jvp (fc, x, dx, prep. cache; relstep, absstep)
89+ finite_difference_jvp (fc, x, dx, prep. cache; relstep, absstep, dir )
8790 end
8891 return ty
8992end
@@ -96,21 +99,22 @@ function DI.value_and_pushforward(
9699 tx:: NTuple ,
97100 contexts:: Vararg{DI.Context,C} ,
98101) where {C}
99- (; relstep, absstep) = prep
102+ (; relstep, absstep, dir ) = prep
100103 fc = DI. with_contexts (f, contexts... )
101104 y = fc (x)
102105 ty = map (tx) do dx
103- finite_difference_jvp (fc, x, dx, prep. cache, y; relstep, absstep)
106+ finite_difference_jvp (fc, x, dx, prep. cache, y; relstep, absstep, dir )
104107 end
105108 return y, ty
106109end
107110
108111# # Derivative
109112
110- struct FiniteDiffOneArgDerivativePrep{C,R,A} <: DI.DerivativePrep
113+ struct FiniteDiffOneArgDerivativePrep{C,R,A,D } <: DI.DerivativePrep
111114 cache:: C
112115 relstep:: R
113116 absstep:: A
117+ dir:: D
114118end
115119
116120function DI. prepare_derivative (
@@ -134,7 +138,8 @@ function DI.prepare_derivative(
134138 else
135139 backend. relstep
136140 end
137- return FiniteDiffOneArgDerivativePrep (cache, relstep, absstep)
141+ dir = backend. dir
142+ return FiniteDiffOneArgDerivativePrep (cache, relstep, absstep, dir)
138143end
139144
140145# ## Scalar to scalar
@@ -146,9 +151,9 @@ function DI.derivative(
146151 x,
147152 contexts:: Vararg{DI.Context,C} ,
148153) where {C}
149- (; relstep, absstep) = prep
154+ (; relstep, absstep, dir ) = prep
150155 fc = DI. with_contexts (f, contexts... )
151- return finite_difference_derivative (fc, x, fdtype (backend); relstep, absstep)
156+ return finite_difference_derivative (fc, x, fdtype (backend); relstep, absstep, dir )
152157end
153158
154159function DI. value_and_derivative (
@@ -158,13 +163,13 @@ function DI.value_and_derivative(
158163 x,
159164 contexts:: Vararg{DI.Context,C} ,
160165) where {C}
161- (; relstep, absstep) = prep
166+ (; relstep, absstep, dir ) = prep
162167 fc = DI. with_contexts (f, contexts... )
163168 y = fc (x)
164169 return (
165170 y,
166171 finite_difference_derivative (
167- fc, x, fdtype (backend), eltype (y), y; relstep, absstep
172+ fc, x, fdtype (backend), eltype (y), y; relstep, absstep, dir
168173 ),
169174 )
170175end
@@ -178,9 +183,9 @@ function DI.derivative(
178183 x,
179184 contexts:: Vararg{DI.Context,C} ,
180185) where {C}
181- (; relstep, absstep) = prep
186+ (; relstep, absstep, dir ) = prep
182187 fc = DI. with_contexts (f, contexts... )
183- return finite_difference_gradient (fc, x, prep. cache; relstep, absstep)
188+ return finite_difference_gradient (fc, x, prep. cache; relstep, absstep, dir )
184189end
185190
186191function DI. derivative! (
@@ -191,9 +196,9 @@ function DI.derivative!(
191196 x,
192197 contexts:: Vararg{DI.Context,C} ,
193198) where {C}
194- (; relstep, absstep) = prep
199+ (; relstep, absstep, dir ) = prep
195200 fc = DI. with_contexts (f, contexts... )
196- return finite_difference_gradient! (der, fc, x, prep. cache; relstep, absstep)
201+ return finite_difference_gradient! (der, fc, x, prep. cache; relstep, absstep, dir )
197202end
198203
199204function DI. value_and_derivative (
@@ -204,9 +209,9 @@ function DI.value_and_derivative(
204209 contexts:: Vararg{DI.Context,C} ,
205210) where {C}
206211 fc = DI. with_contexts (f, contexts... )
207- (; relstep, absstep) = prep
212+ (; relstep, absstep, dir ) = prep
208213 y = fc (x)
209- return (y, finite_difference_gradient (fc, x, prep. cache; relstep, absstep))
214+ return (y, finite_difference_gradient (fc, x, prep. cache; relstep, absstep, dir ))
210215end
211216
212217function DI. value_and_derivative! (
@@ -217,17 +222,20 @@ function DI.value_and_derivative!(
217222 x,
218223 contexts:: Vararg{DI.Context,C} ,
219224) where {C}
220- (; relstep, absstep) = prep
225+ (; relstep, absstep, dir ) = prep
221226 fc = DI. with_contexts (f, contexts... )
222- return (fc (x), finite_difference_gradient! (der, fc, x, prep. cache; relstep, absstep))
227+ return (
228+ fc (x), finite_difference_gradient! (der, fc, x, prep. cache; relstep, absstep, dir)
229+ )
223230end
224231
225232# # Gradient
226233
227- struct FiniteDiffGradientPrep{C,R,A} <: DI.GradientPrep
234+ struct FiniteDiffGradientPrep{C,R,A,D } <: DI.GradientPrep
228235 cache:: C
229236 relstep:: R
230237 absstep:: A
238+ dir:: D
231239end
232240
233241function DI. prepare_gradient (
@@ -247,7 +255,8 @@ function DI.prepare_gradient(
247255 else
248256 backend. relstep
249257 end
250- return FiniteDiffGradientPrep (cache, relstep, absstep)
258+ dir = backend. dir
259+ return FiniteDiffGradientPrep (cache, relstep, absstep, dir)
251260end
252261
253262function DI. gradient (
@@ -257,9 +266,9 @@ function DI.gradient(
257266 x:: AbstractArray ,
258267 contexts:: Vararg{DI.Context,C} ,
259268) where {C}
260- (; relstep, absstep) = prep
269+ (; relstep, absstep, dir ) = prep
261270 fc = DI. with_contexts (f, contexts... )
262- return finite_difference_gradient (fc, x, prep. cache; relstep, absstep)
271+ return finite_difference_gradient (fc, x, prep. cache; relstep, absstep, dir )
263272end
264273
265274function DI. value_and_gradient (
@@ -269,9 +278,9 @@ function DI.value_and_gradient(
269278 x:: AbstractArray ,
270279 contexts:: Vararg{DI.Context,C} ,
271280) where {C}
272- (; relstep, absstep) = prep
281+ (; relstep, absstep, dir ) = prep
273282 fc = DI. with_contexts (f, contexts... )
274- return fc (x), finite_difference_gradient (fc, x, prep. cache; relstep, absstep)
283+ return fc (x), finite_difference_gradient (fc, x, prep. cache; relstep, absstep, dir )
275284end
276285
277286function DI. gradient! (
@@ -282,9 +291,9 @@ function DI.gradient!(
282291 x:: AbstractArray ,
283292 contexts:: Vararg{DI.Context,C} ,
284293) where {C}
285- (; relstep, absstep) = prep
294+ (; relstep, absstep, dir ) = prep
286295 fc = DI. with_contexts (f, contexts... )
287- return finite_difference_gradient! (grad, fc, x, prep. cache; relstep, absstep)
296+ return finite_difference_gradient! (grad, fc, x, prep. cache; relstep, absstep, dir )
288297end
289298
290299function DI. value_and_gradient! (
@@ -295,17 +304,20 @@ function DI.value_and_gradient!(
295304 x:: AbstractArray ,
296305 contexts:: Vararg{DI.Context,C} ,
297306) where {C}
298- (; relstep, absstep) = prep
307+ (; relstep, absstep, dir ) = prep
299308 fc = DI. with_contexts (f, contexts... )
300- return (fc (x), finite_difference_gradient! (grad, fc, x, prep. cache; relstep, absstep))
309+ return (
310+ fc (x), finite_difference_gradient! (grad, fc, x, prep. cache; relstep, absstep, dir)
311+ )
301312end
302313
303314# # Jacobian
304315
305- struct FiniteDiffOneArgJacobianPrep{C,R,A} <: DI.JacobianPrep
316+ struct FiniteDiffOneArgJacobianPrep{C,R,A,D } <: DI.JacobianPrep
306317 cache:: C
307318 relstep:: R
308319 absstep:: A
320+ dir:: D
309321end
310322
311323function DI. prepare_jacobian (
@@ -327,7 +339,8 @@ function DI.prepare_jacobian(
327339 else
328340 backend. relstep
329341 end
330- return FiniteDiffOneArgJacobianPrep (cache, relstep, absstep)
342+ dir = backend. dir
343+ return FiniteDiffOneArgJacobianPrep (cache, relstep, absstep, dir)
331344end
332345
333346function DI. jacobian (
@@ -337,9 +350,9 @@ function DI.jacobian(
337350 x,
338351 contexts:: Vararg{DI.Context,C} ,
339352) where {C}
340- (; relstep, absstep) = prep
353+ (; relstep, absstep, dir ) = prep
341354 fc = DI. with_contexts (f, contexts... )
342- return finite_difference_jacobian (fc, x, prep. cache; relstep, absstep)
355+ return finite_difference_jacobian (fc, x, prep. cache; relstep, absstep, dir )
343356end
344357
345358function DI. value_and_jacobian (
@@ -350,9 +363,9 @@ function DI.value_and_jacobian(
350363 contexts:: Vararg{DI.Context,C} ,
351364) where {C}
352365 fc = DI. with_contexts (f, contexts... )
353- (; relstep, absstep) = prep
366+ (; relstep, absstep, dir ) = prep
354367 y = fc (x)
355- return (y, finite_difference_jacobian (fc, x, prep. cache, y; relstep, absstep))
368+ return (y, finite_difference_jacobian (fc, x, prep. cache, y; relstep, absstep, dir ))
356369end
357370
358371function DI. jacobian! (
@@ -363,11 +376,13 @@ function DI.jacobian!(
363376 x,
364377 contexts:: Vararg{DI.Context,C} ,
365378) where {C}
366- (; relstep, absstep) = prep
379+ (; relstep, absstep, dir ) = prep
367380 fc = DI. with_contexts (f, contexts... )
368381 return copyto! (
369382 jac,
370- finite_difference_jacobian (fc, x, prep. cache; jac_prototype= jac, relstep, absstep),
383+ finite_difference_jacobian (
384+ fc, x, prep. cache; jac_prototype= jac, relstep, absstep, dir
385+ ),
371386 )
372387end
373388
@@ -379,15 +394,15 @@ function DI.value_and_jacobian!(
379394 x,
380395 contexts:: Vararg{DI.Context,C} ,
381396) where {C}
382- (; relstep, absstep) = prep
397+ (; relstep, absstep, dir ) = prep
383398 fc = DI. with_contexts (f, contexts... )
384399 y = fc (x)
385400 return (
386401 y,
387402 copyto! (
388403 jac,
389404 finite_difference_jacobian (
390- fc, x, prep. cache, y; jac_prototype= jac, relstep, absstep
405+ fc, x, prep. cache, y; jac_prototype= jac, relstep, absstep, dir
391406 ),
392407 ),
393408 )
0 commit comments