Skip to content

Commit 1e8df98

Browse files
changes from reviews, Docs
1 parent 13233e5 commit 1e8df98

3 files changed

Lines changed: 9 additions & 10 deletions

File tree

DifferentiationInterface/docs/src/explanation/backends.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ In general, using a forward outer backend over a reverse inner backend will yiel
9595
The wrapper [`DifferentiateWith`](@ref) allows you to switch between backends.
9696
It takes a function `f` and specifies that `f` should be differentiated with the substitute backend of your choice, instead of whatever true backend the surrounding code is trying to use.
9797
In other words, when someone tries to differentiate `dw = DifferentiateWith(f, substitute_backend)` with `true_backend`, then `substitute_backend` steps in and `true_backend` does not dive into the function `f` itself.
98-
At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend.
98+
At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) or a [Mooncake.jl](https://github.com/compintell/Mooncake.jl)-compatible backend.
9999

100100
## Implementations
101101

DifferentiationInterface/docs/src/faq/differentiability.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,4 +111,4 @@ There are, however, translation utilities:
111111
### Backend switch
112112

113113
Also note the existence of [`DifferentiationInterface.DifferentiateWith`](@ref), which allows the user to wrap a function that should be differentiated with a specific backend.
114-
Right now it only targets ForwardDiff.jl and ChainRulesCore.jl, but PRs are welcome to define Enzyme.jl and Mooncake.jl rules for this object.
114+
Right now it only targets ForwardDiff.jl, ChainRulesCore.jl and Mooncake.jl but PRs are welcome to define Enzyme.jl rules for this object.

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:AbstractArray}
2-
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Number}
1+
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Union{Number,AbstractArray}}
32

43
function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number})
54
primal_func = primal(dw)
@@ -14,7 +13,7 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number
1413
end
1514

1615
# output is a scalar, so we can use the scalar pullback
17-
function pullback!!(dy)
16+
function pullback!!(dy::Number)
1817
tx = DI.pullback(f, backend, primal_x, (dy,))
1918
return NoRData(), only(tx)
2019
end
@@ -28,22 +27,22 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra
2827
fdata_arg = fdata(x.dx)
2928
(; f, backend) = primal_func
3029
y = zero_fcodual(f(primal_x))
30+
# in case x is mutated in f calls
31+
cp_primal_x = copy(primal_x)
3132

3233
# output is a vector, so we need to use the vector pullback
3334
function pullback!!(dy::NoRData)
34-
tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),))
35+
tx = DI.pullback(f, backend, cp_primal_x, (fdata(y.dx),))
3536
fdata_arg .+= only(tx)
3637
return NoRData(), dy
3738
end
3839

3940
# output is a scalar, so we can use the scalar pullback
40-
function pullback!!(dy)
41-
tx = DI.pullback(f, backend, primal_x, (dy,))
41+
function pullback!!(dy::Number)
42+
tx = DI.pullback(f, backend, cp_primal_x, (dy,))
4243
fdata_arg .+= only(tx)
4344
return NoRData(), NoRData()
4445
end
4546

46-
# in case x is mutated when passed into f
47-
x = CoDual(primal_x, x.dx)
4847
return y, pullback!!
4948
end

0 commit comments

Comments
 (0)