Skip to content

Commit 13233e5

Browse files
too easy
1 parent 2e95299 commit 13233e5

1 file changed

Lines changed: 23 additions & 6 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,28 +5,45 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number
55
primal_func = primal(dw)
66
primal_x = primal(x)
77
(; f, backend) = primal_func
8-
y = f(primal_x)
8+
y = zero_fcodual(f(primal_x))
99

10+
# output is a vector, so we need to use the vector pullback
11+
function pullback!!(dy::NoRData)
12+
tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),))
13+
return NoRData(), only(tx)
14+
end
15+
16+
# output is a scalar, so we can use the scalar pullback
1017
function pullback!!(dy)
1118
tx = DI.pullback(f, backend, primal_x, (dy,))
12-
return (NoRData(), only(tx))
19+
return NoRData(), only(tx)
1320
end
1421

15-
return zero_fcodual(y), pullback!!
22+
return y, pullback!!
1623
end
1724

1825
function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray})
1926
primal_func = primal(dw)
2027
primal_x = primal(x)
2128
fdata_arg = fdata(x.dx)
2229
(; f, backend) = primal_func
23-
y = f(primal_x)
30+
y = zero_fcodual(f(primal_x))
31+
32+
# output is a vector, so we need to use the vector pullback
33+
function pullback!!(dy::NoRData)
34+
tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),))
35+
fdata_arg .+= only(tx)
36+
return NoRData(), dy
37+
end
2438

39+
# output is a scalar, so we can use the scalar pullback
2540
function pullback!!(dy)
2641
tx = DI.pullback(f, backend, primal_x, (dy,))
2742
fdata_arg .+= only(tx)
28-
return (NoRData(), NoRData())
43+
return NoRData(), NoRData()
2944
end
3045

31-
return zero_fcodual(y), pullback!!
46+
# in case x is mutated when passed into f
47+
x = CoDual(primal_x, x.dx)
48+
return y, pullback!!
3249
end

0 commit comments

Comments
 (0)