1- struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, F, N} <: DI.PullbackPrep{SIG}
1+ struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, N} <: DI.PullbackPrep{SIG}
22 _sig:: Val{SIG}
33 cache:: Tcache
44 dy_backup:: DY
5- target_function:: F
65 args_to_zero:: NTuple{N, Bool}
76end
87
@@ -16,13 +15,9 @@ function DI.prepare_pullback_nokwarg(
1615 contexts:: Vararg{DI.Context, C}
1716 ) where {F, C}
1817 _sig = DI. signature (f!, y, backend, x, ty, contexts... ; strict)
19- target_function = function (f!, y, x, contexts... )
20- f! (y, x, contexts... )
21- return y
22- end
2318 config = get_config (backend)
2419 cache = prepare_pullback_cache (
25- target_function ,
20+ call_and_return ,
2621 f!,
2722 y,
2823 x,
@@ -32,14 +27,14 @@ function DI.prepare_pullback_nokwarg(
3227 dy_backup_after = zero_tangent (y)
3328 contexts_tup_false = map (_ -> false , contexts)
3429 args_to_zero = (
35- false , # target_function
30+ false , # call_and_return
3631 false , # f!
3732 false , # y
3833 true , # x
3934 contexts_tup_false... ,
4035 )
4136 prep = MooncakeTwoArgPullbackPrep (
42- _sig, cache, dy_backup_after, target_function, args_to_zero
37+ _sig, cache, dy_backup_after, args_to_zero
4338 )
4439 return prep
4540end
@@ -61,7 +56,7 @@ function DI.value_and_pullback(
6156 y_after, (_, _, _, dx) = value_and_pullback!! (
6257 prep. cache,
6358 dy_backup_after,
64- prep . target_function ,
59+ call_and_return ,
6560 f!,
6661 y,
6762 x,
@@ -87,7 +82,7 @@ function DI.value_and_pullback(
8782 y_after, (_, _, _, dx) = value_and_pullback!! (
8883 prep. cache,
8984 dy_backup_after,
90- prep . target_function ,
85+ call_and_return ,
9186 f!,
9287 y,
9388 x,
0 commit comments