Skip to content

Commit afdddd4

Browse files
changes from reviews - 2
1 parent 1e8df98 commit afdddd4

3 files changed

Lines changed: 10 additions & 12 deletions

File tree

DifferentiationInterface/docs/src/explanation/backends.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ In general, using a forward outer backend over a reverse inner backend will yiel
9595
The wrapper [`DifferentiateWith`](@ref) allows you to switch between backends.
9696
It takes a function `f` and specifies that `f` should be differentiated with the substitute backend of your choice, instead of whatever true backend the surrounding code is trying to use.
9797
In other words, when someone tries to differentiate `dw = DifferentiateWith(f, substitute_backend)` with `true_backend`, then `substitute_backend` steps in and `true_backend` does not dive into the function `f` itself.
98-
At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) or a [Mooncake.jl](https://github.com/compintell/Mooncake.jl)-compatible backend.
98+
At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](https://github.com/compintell/Mooncake.jl), or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)).
9999

100100
## Implementations
101101

DifferentiationInterface/docs/src/faq/differentiability.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,4 +111,4 @@ There are, however, translation utilities:
111111
### Backend switch
112112

113113
Also note the existence of [`DifferentiationInterface.DifferentiateWith`](@ref), which allows the user to wrap a function that should be differentiated with a specific backend.
114-
Right now it only targets ForwardDiff.jl, ChainRulesCore.jl and Mooncake.jl but PRs are welcome to define Enzyme.jl rules for this object.
114+
Right now, it only targets ForwardDiff.jl, Mooncake.jl, ChainRules.jl-compatible backends (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)), but PRs are welcome to define Enzyme.jl rules for this object.

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,18 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number
77
y = zero_fcodual(f(primal_x))
88

99
# output is a vector, so we need to use the vector pullback
10-
function pullback!!(dy::NoRData)
10+
function pullback_array!!(dy::NoRData)
1111
tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),))
1212
return NoRData(), only(tx)
1313
end
1414

1515
# output is a scalar, so we can use the scalar pullback
16-
function pullback!!(dy::Number)
16+
function pullback_scalar!!(dy::Number)
1717
tx = DI.pullback(f, backend, primal_x, (dy,))
1818
return NoRData(), only(tx)
1919
end
2020

21-
return y, pullback!!
21+
return y, typeof(primal(y)) <: Number ? pullback_scalar!! : pullback_array!!
2222
end
2323

2424
function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray})
@@ -27,22 +27,20 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra
2727
fdata_arg = fdata(x.dx)
2828
(; f, backend) = primal_func
2929
y = zero_fcodual(f(primal_x))
30-
# in case x is mutated in f calls
31-
cp_primal_x = copy(primal_x)
3230

3331
# output is a vector, so we need to use the vector pullback
34-
function pullback!!(dy::NoRData)
35-
tx = DI.pullback(f, backend, cp_primal_x, (fdata(y.dx),))
32+
function pullback_array!!(dy::NoRData)
33+
tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),))
3634
fdata_arg .+= only(tx)
3735
return NoRData(), dy
3836
end
3937

4038
# output is a scalar, so we can use the scalar pullback
41-
function pullback!!(dy::Number)
42-
tx = DI.pullback(f, backend, cp_primal_x, (dy,))
39+
function pullback_scalar!!(dy::Number)
40+
tx = DI.pullback(f, backend, primal_x, (dy,))
4341
fdata_arg .+= only(tx)
4442
return NoRData(), NoRData()
4543
end
4644

47-
return y, pullback!!
45+
return y, typeof(primal(y)) <: Number ? pullback_scalar!! : pullback_array!!
4846
end

0 commit comments

Comments
 (0)