@@ -5,28 +5,45 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number
55 primal_func = primal (dw)
66 primal_x = primal (x)
77 (; f, backend) = primal_func
8- y = f (primal_x)
8+ y = zero_fcodual ( f (primal_x) )
99
10+ # output is a vector, so we need to use the vector pullback
11+ function pullback!! (dy:: NoRData )
12+ tx = DI. pullback (f, backend, primal_x, (fdata (y. dx),))
13+ return NoRData (), only (tx)
14+ end
15+
16+ # output is a scalar, so we can use the scalar pullback
1017 function pullback!! (dy)
1118 tx = DI. pullback (f, backend, primal_x, (dy,))
12- return ( NoRData (), only (tx) )
19+ return NoRData (), only (tx)
1320 end
1421
15- return zero_fcodual (y) , pullback!!
22+ return y , pullback!!
1623end
1724
1825function Mooncake. rrule!! (dw:: CoDual{<:DI.DifferentiateWith} , x:: CoDual{<:AbstractArray} )
1926 primal_func = primal (dw)
2027 primal_x = primal (x)
2128 fdata_arg = fdata (x. dx)
2229 (; f, backend) = primal_func
23- y = f (primal_x)
30+ y = zero_fcodual (f (primal_x))
31+
32+ # output is a vector, so we need to use the vector pullback
33+ function pullback!! (dy:: NoRData )
34+ tx = DI. pullback (f, backend, primal_x, (fdata (y. dx),))
35+ fdata_arg .+ = only (tx)
36+ return NoRData (), dy
37+ end
2438
39+ # output is a scalar, so we can use the scalar pullback
2540 function pullback!! (dy)
2641 tx = DI. pullback (f, backend, primal_x, (dy,))
2742 fdata_arg .+ = only (tx)
28- return ( NoRData (), NoRData () )
43+ return NoRData (), NoRData ()
2944 end
3045
31- return zero_fcodual (y), pullback!!
46+ # in case x is mutated when passed into f
47+ x = CoDual (primal_x, x. dx)
48+ return y, pullback!!
3249end
0 commit comments