# Preparation object for the pullback of an in-place function `f!` that is
# differentiated together with `y` and `x`. It is built in `DI.prepare_pullback`
# and its fields are read back (and reset) in `DI.value_and_pullback`, so the
# tangent buffers and the copy of `y` are stored here rather than rebuilt per call.
1- struct MooncakeTwoArgPullbackPrep{R} <: PullbackPrep
1+ struct MooncakeTwoArgPullbackPrep{R,F,Y,DX,DY } <: PullbackPrep
# Rule called as `prep.rrule(...)` on the CoDuals of `f!`, `y_copy`, `x` and the
# contexts; the call yields the forward output and the pullback `pb!!`.
22 rrule:: R
# Tangent storage for `f!`, created with `zero_tangent(f!)`.
3+ df!:: F
# Copy of `y` (made with `copy(y)`) that is handed to the rule in place of `y`;
# its contents are copied back into `y` before the reverse pass runs.
4+ y_copy:: Y
# Tangent storage for `x`, created with `zero_tangent(x)`.
5+ dx_righttype:: DX
# Tangent storage for `y`, created with `zero_tangent(y)`; passed to the rule
# alongside `y_copy` and incremented after the forward pass.
6+ dy_righttype:: DY
# Second buffer created with `zero_tangent(y)`; the requested cotangent `dy` is
# copied into it with `copyto!` and it is then used to increment `dy_righttype`.
7+ dy_righttype_after:: DY
38end
49
510function DI. prepare_pullback (
@@ -12,7 +17,14 @@ function DI.prepare_pullback(
1217 debug_mode= config. debug_mode,
1318 silence_debug_messages= config. silence_debug_messages,
1419 )
15- prep = MooncakeTwoArgPullbackPrep (rrule)
20+ df! = zero_tangent (f!)
21+ y_copy = copy (y)
22+ dx_righttype = zero_tangent (x)
23+ dy_righttype = zero_tangent (y)
24+ dy_righttype_after = zero_tangent (y)
25+ prep = MooncakeTwoArgPullbackPrep (
26+ rrule, df!, y_copy, dx_righttype, dy_righttype, dy_righttype_after
27+ )
1628 DI. value_and_pullback (f!, y, prep, backend, x, ty, contexts... ) # warm up
1729 return prep
1830end
@@ -27,46 +39,38 @@ function DI.value_and_pullback(
2739 contexts:: Vararg{Context,C} ,
2840) where {C}
2941 dy = only (ty)
30- dy_righttype = convert (tangent_type (typeof (y)), copy (dy))
31- dx_righttype = zero_tangent (x)
3242
33- # We want the VJP, not VJP + dx, so I'm going to zero-out `dx`. `set_to_zero!!` has the advantage
34- # that it will also replace any immutable components of `dx` to zero.
35- dx_righttype = set_to_zero!! (dx_righttype)
43+ # Set all tangent storage to zero.
44+ df! = set_to_zero!! (prep. df!)
45+ dx_righttype = set_to_zero!! (prep. dx_righttype)
46+ dy_righttype = set_to_zero!! (prep. dy_righttype)
3647
37- # We want `dy` to correspond to the cotangent of `y` _after_
38- # running the forwards-pass, so I'm going to take a copy, and zero-out the original.
39- dy_righttype_backup = copy (dy_righttype)
40- dy_righttype = set_to_zero!! (dy_righttype)
41- contexts_coduals = map (zero_fcodual ∘ unwrap, contexts)
42-
43- # Mutate a copy of `y`, so that we can run the reverse-pass later on.
44- y_copy = copy (y)
48+ # Prepare cotangent to add after the forward pass.
49+ dy_righttype_after = copyto! (prep. dy_righttype_after, dy)
4550
46- # In case `f!` is a closure
47- df! = zero_tangent (f!)
51+ contexts_coduals = map (zero_fcodual ∘ unwrap, contexts)
4852
49- # Run the forwards- pass.
53+ # Run the forward pass
5054 out, pb!! = prep. rrule (
5155 CoDual (f!, fdata (df!)),
52- CoDual (y_copy, fdata (dy_righttype)),
56+ CoDual (prep . y_copy, fdata (dy_righttype)),
5357 CoDual (x, fdata (dx_righttype)),
5458 contexts_coduals... ,
5559 )
5660
5761 # Verify that the output is non-differentiable.
5862 @assert primal (out) === nothing
5963
60- # Set the cotangent of `y` to be equal to the requested value .
61- dy_righttype = increment!! (dy_righttype, dy_righttype_backup )
64+ # Increment the desired cotangent dy .
65+ dy_righttype = increment!! (dy_righttype, dy_righttype_after )
6266
63- # Record the state of `y` before running the reverse- pass.
64- y = copyto! (y, y_copy)
67+ # Record the state of y before running the reverse pass.
68+ y = copyto! (y, prep . y_copy)
6569
66- # Run the reverse- pass.
70+ # Run the reverse pass.
6771 _, _, new_dx = pb!! (NoRData ())
6872
69- return y, (tangent (fdata (dx_righttype), new_dx),)
73+ return y, (tangent (copy ( fdata (dx_righttype)) , new_dx),) # TODO : remove this allocation in `value_and_pullback!`
7074end
7175
7276function DI. value_and_pullback (
0 commit comments