11# # Pushforward
22
3- struct MooncakeTwoArgPushforwardPrep{SIG, Tcache, FT , CT} <: DI.PushforwardPrep{SIG}
3+ struct MooncakeTwoArgPushforwardPrep{SIG, Tcache, FT0, FT, YT , CT} <: DI.PushforwardPrep{SIG}
44 _sig:: Val{SIG}
55 cache:: Tcache
6+ dcall:: FT0
67 df!:: FT
8+ dy:: YT
79 context_tangents:: CT
810end
911
@@ -26,9 +28,11 @@ function DI.prepare_pushforward_nokwarg(
2628 map (DI. unwrap, contexts)... ;
2729 config
2830 )
31+ dcall = zero_tangent_or_primal (call_and_return, backend)
2932 df! = zero_tangent_or_primal (f!, backend)
33+ dy = zero_tangent_or_primal (y, backend)
3034 context_tangents = map (zero_tangent_unwrap, contexts)
31- prep = MooncakeTwoArgPushforwardPrep (_sig, cache, df!, context_tangents)
35+ prep = MooncakeTwoArgPushforwardPrep (_sig, cache, dcall, df!, dy , context_tangents)
3236 return prep
3337end
3438
@@ -43,13 +47,11 @@ function DI.value_and_pushforward(
4347 ) where {F, C, X}
4448 DI. check_prep (f!, y, prep, backend, x, tx, contexts... )
4549 ty = map (tx) do dx
46- dy = zero_tangent_or_primal (y, backend) # TODO : remove allocation?
47- dcall = zero_tangent_or_primal (call_and_return, backend)
4850 _, new_dy = value_and_derivative!! (
4951 prep. cache,
50- (call_and_return, dcall),
52+ (call_and_return, prep . dcall),
5153 (f!, prep. df!),
52- (y, dy),
54+ (y, prep . dy),
5355 (x, dx),
5456 map (first_unwrap, contexts, prep. context_tangents)... ,
5557 )
@@ -83,10 +85,9 @@ function DI.value_and_pushforward!(
8385 ) where {F, C, X, Y}
8486 DI. check_prep (f!, y, prep, backend, x, tx, contexts... )
8587 foreach (tx, ty) do dx, dy
86- dcall = zero_tangent_or_primal (call_and_return, backend)
8788 _, new_dy = value_and_derivative!! (
8889 prep. cache,
89- (call_and_return, dcall),
90+ (call_and_return, prep . dcall),
9091 (f!, prep. df!),
9192 (y, dy),
9293 (x, dx),
0 commit comments