@@ -78,6 +78,82 @@ function DI.pushforward!(
7878 return dy
7979end
8080
81+ # # Pullback
82+
83+ struct FastDifferentiationTwoArgPullbackExtras{E1,E2} <: PullbackExtras
84+ vjp_exe:: E1
85+ vjp_exe!:: E2
86+ end
87+
88+ function DI. prepare_pullback (f!, y, :: AutoFastDifferentiation , x, dy)
89+ x_var = if x isa Number
90+ only (make_variables (:x ))
91+ else
92+ make_variables (:x , size (x)... )
93+ end
94+ y_var = make_variables (:y , size (y)... )
95+ f! (y_var, x_var)
96+
97+ x_vec_var = x_var isa Number ? monovec (x_var) : vec (x_var)
98+ y_vec_var = y_var isa Number ? monovec (y_var) : vec (y_var)
99+ vj_vec_var, v_vec_var = jacobian_transpose_v (y_vec_var, x_vec_var)
100+ vjp_exe = make_function (vj_vec_var, vcat (x_vec_var, v_vec_var); in_place= false )
101+ vjp_exe! = make_function (vj_vec_var, vcat (x_vec_var, v_vec_var); in_place= true )
102+ return FastDifferentiationTwoArgPullbackExtras (vjp_exe, vjp_exe!)
103+ end
104+
105+ function DI. pullback (
106+ f!, y, :: AutoFastDifferentiation , x, dy, extras:: FastDifferentiationTwoArgPullbackExtras
107+ )
108+ v_vec = vcat (myvec (x), myvec (dy))
109+ if x isa Number
110+ return only (extras. vjp_exe (v_vec))
111+ else
112+ return reshape (extras. vjp_exe (v_vec), size (x))
113+ end
114+ end
115+
116+ function DI. pullback! (
117+ f!,
118+ y,
119+ dx,
120+ :: AutoFastDifferentiation ,
121+ x,
122+ dy,
123+ extras:: FastDifferentiationTwoArgPullbackExtras ,
124+ )
125+ v_vec = vcat (myvec (x), myvec (dy))
126+ extras. vjp_exe! (vec (dx), v_vec)
127+ return dx
128+ end
129+
130+ function DI. value_and_pullback (
131+ f!,
132+ y,
133+ backend:: AutoFastDifferentiation ,
134+ x,
135+ dy,
136+ extras:: FastDifferentiationTwoArgPullbackExtras ,
137+ )
138+ dx = DI. pullback (f!, y, backend, x, dy, extras)
139+ f! (y, x)
140+ return y, dx
141+ end
142+
143+ function DI. value_and_pullback! (
144+ f!,
145+ y,
146+ dx,
147+ backend:: AutoFastDifferentiation ,
148+ x,
149+ dy,
150+ extras:: FastDifferentiationTwoArgPullbackExtras ,
151+ )
152+ DI. pullback! (f!, y, dx, backend, x, dy, extras)
153+ f! (y, x)
154+ return y, dx
155+ end
156+
81157# # Derivative
82158
83159struct FastDifferentiationTwoArgDerivativeExtras{E1,E2} <: DerivativeExtras
0 commit comments