11# # Pushforward
22
3- struct FiniteDiffTwoArgPushforwardPrep{R,A} <: DI.PushforwardPrep
3+ struct FiniteDiffTwoArgPushforwardPrep{C,R,A} <: DI.PushforwardPrep
4+ cache:: C
45 relstep:: R
56 absstep:: A
67end
78
89function DI. prepare_pushforward (
910 f!, y, backend:: AutoFiniteDiff , x, tx:: NTuple , contexts:: Vararg{DI.Context,C}
1011) where {C}
12+ cache = if x isa Number
13+ nothing
14+ else
15+ JVPCache (similar (x), similar (y), fdtype (backend))
16+ end
1117 relstep = if isnothing (backend. relstep)
1218 default_relstep (fdtype (backend), eltype (x))
1319 else
@@ -18,14 +24,13 @@ function DI.prepare_pushforward(
1824 else
1925 backend. relstep
2026 end
21- return FiniteDiffTwoArgPushforwardPrep (relstep, absstep)
22- return DI. NoPushforwardPrep ()
27+ return FiniteDiffTwoArgPushforwardPrep (cache, relstep, absstep)
2328end
2429
2530function DI. value_and_pushforward (
2631 f!,
2732 y,
28- prep:: FiniteDiffTwoArgPushforwardPrep ,
33+ prep:: FiniteDiffTwoArgPushforwardPrep{Nothing} ,
2934 backend:: AutoFiniteDiff ,
3035 x,
3136 tx:: NTuple ,
@@ -52,6 +57,84 @@ function DI.value_and_pushforward(
5257 return y, ty
5358end
5459
60+ function DI. pushforward (
61+ f!,
62+ y,
63+ prep:: FiniteDiffTwoArgPushforwardPrep{<:JVPCache} ,
64+ :: AutoFiniteDiff ,
65+ x,
66+ tx:: NTuple ,
67+ contexts:: Vararg{DI.Context,C} ,
68+ ) where {C}
69+ (; relstep, absstep) = prep
70+ fc! = DI. with_contexts (f!, contexts... )
71+ ty = map (tx) do dx
72+ dy = similar (y)
73+ finite_difference_jvp! (dy, fc!, x, dx, prep. cache; relstep, absstep)
74+ dy
75+ end
76+ return ty
77+ end
78+
79+ function DI. value_and_pushforward (
80+ f!,
81+ y,
82+ prep:: FiniteDiffTwoArgPushforwardPrep{<:JVPCache} ,
83+ :: AutoFiniteDiff ,
84+ x,
85+ tx:: NTuple ,
86+ contexts:: Vararg{DI.Context,C} ,
87+ ) where {C}
88+ (; relstep, absstep) = prep
89+ fc! = DI. with_contexts (f!, contexts... )
90+ ty = map (tx) do dx
91+ dy = similar (y)
92+ finite_difference_jvp! (dy, fc!, x, dx, prep. cache; relstep, absstep)
93+ dy
94+ end
95+ fc! (y, x)
96+ return y, ty
97+ end
98+
99+ function DI. pushforward! (
100+ f!,
101+ y,
102+ ty:: NTuple ,
103+ prep:: FiniteDiffTwoArgPushforwardPrep{<:JVPCache} ,
104+ :: AutoFiniteDiff ,
105+ x,
106+ tx:: NTuple ,
107+ contexts:: Vararg{DI.Context,C} ,
108+ ) where {C}
109+ (; relstep, absstep) = prep
110+ fc! = DI. with_contexts (f!, contexts... )
111+ for b in eachindex (tx, ty)
112+ dx, dy = tx[b], ty[b]
113+ finite_difference_jvp! (dy, fc!, x, dx, prep. cache; relstep, absstep)
114+ end
115+ return ty
116+ end
117+
118+ function DI. value_and_pushforward! (
119+ f!,
120+ y,
121+ ty:: NTuple ,
122+ prep:: FiniteDiffTwoArgPushforwardPrep{<:JVPCache} ,
123+ :: AutoFiniteDiff ,
124+ x,
125+ tx:: NTuple ,
126+ contexts:: Vararg{DI.Context,C} ,
127+ ) where {C}
128+ (; relstep, absstep) = prep
129+ fc! = DI. with_contexts (f!, contexts... )
130+ for b in eachindex (tx, ty)
131+ dx, dy = tx[b], ty[b]
132+ finite_difference_jvp! (dy, fc!, x, dx, prep. cache; relstep, absstep)
133+ end
134+ fc! (y, x)
135+ return y, ty
136+ end
137+
55138# # Derivative
56139
57140struct FiniteDiffTwoArgDerivativePrep{C,R,A} <: DI.DerivativePrep
0 commit comments