11# # Pushforward
22
3- # TODO : needs friendly tangents support
4-
53struct MooncakeTwoArgPushforwardPrep{SIG, Tcache, DX, DY, FT, CT} <: DI.PushforwardPrep{SIG}
64 _sig:: Val{SIG}
75 cache:: Tcache
@@ -29,8 +27,13 @@ function DI.prepare_pushforward_nokwarg(
2927 map (DI. unwrap, contexts)... ;
3028 config,
3129 )
32- dx_righttype = zero_tangent (x)
33- dy_righttype = zero_tangent (y)
30+ if config. friendly_tangents
31+ dx_righttype = zero_tangent (x)
32+ dy_righttype = zero_tangent (y)
33+ else
34+ dx_righttype = nothing
35+ dy_righttype = nothing
36+ end
3437 df! = zero_tangent (f!)
3538 context_tangents = map (zero_tangent_unwrap, contexts)
3639 prep = MooncakeTwoArgPushforwardPrep (_sig, cache, dx_righttype, dy_righttype, df!, context_tangents)
@@ -49,7 +52,7 @@ function DI.value_and_pushforward(
4952 DI. check_prep (f!, y, prep, backend, x, tx, contexts... )
5053 ty = map (tx) do dx
5154 dx_righttype =
52- dx isa tangent_type (X ) ? dx : _copy_to_output !! (prep. dx_righttype, dx)
55+ isnothing (prep . dx_righttype ) ? dx : primal_to_tangent !! (prep. dx_righttype, dx)
5356 y_dual = zero_dual (y)
5457 value_and_derivative!! (
5558 prep. cache,
@@ -58,7 +61,11 @@ function DI.value_and_pushforward(
5861 Dual (x, dx_righttype),
5962 map (Dual_unwrap, contexts, prep. context_tangents)... ,
6063 )
61- dy = _copy_output (tangent (y_dual))
64+ if isnothing (prep. dx_righttype)
65+ dy = _copy_output (tangent (y_dual))
66+ else
67+ dy = tangent_to_primal!! (_copy_output (y), tangent (y_dual))
68+ end
6269 return dy
6370 end
6471 return y, ty
@@ -90,17 +97,17 @@ function DI.value_and_pushforward!(
9097 DI. check_prep (f!, y, prep, backend, x, tx, contexts... )
9198 foreach (tx, ty) do dx, dy
9299 dx_righttype =
93- dx isa tangent_type (X ) ? dx : _copy_to_output !! (prep. dx_righttype, dx)
100+ isnothing (prep . dx_righttype ) ? dx : primal_to_tangent !! (prep. dx_righttype, dx)
94101 dy_righttype =
95- dy isa tangent_type (Y ) ? dy : _copy_to_output !! (prep. dy_righttype, dy)
102+ isnothing (prep . dy_righttype ) ? dy : primal_to_tangent !! (prep. dy_righttype, dy)
96103 value_and_derivative!! (
97104 prep. cache,
98105 Dual (f!, prep. df!),
99106 Dual (y, dy_righttype),
100107 Dual (x, dx_righttype),
101108 map (Dual_unwrap, contexts, prep. context_tangents)... ,
102109 )
103- dy === dy_righttype || copyto ! (dy, dy_righttype)
110+ isnothing (prep . dy_righttype) || tangent_to_primal! ! (dy, dy_righttype)
104111 end
105112 return y, ty
106113end
0 commit comments