Commit 7a07127

changes from reviews-1
1 parent 233c312 commit 7a07127

5 files changed

Lines changed: 14 additions & 7 deletions

DifferentiationInterface/docs/src/explanation/backends.md

Lines changed: 2 additions & 2 deletions
@@ -95,7 +95,7 @@ In general, using a forward outer backend over a reverse inner backend will yiel
 The wrapper [`DifferentiateWith`](@ref) allows you to switch between backends.
 It takes a function `f` and specifies that `f` should be differentiated with the substitute backend of your choice, instead of whatever true backend the surrounding code is trying to use.
 In other words, when someone tries to differentiate `dw = DifferentiateWith(f, substitute_backend)` with `true_backend`, then `substitute_backend` steps in and `true_backend` does not dive into the function `f` itself.
-At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](https://github.com/compintell/Mooncake.jl), or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)).
+At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl), or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)).
 
 ## Implementations
 
@@ -177,7 +177,7 @@ For all operators, preparation generates an [executable function](https://docs.s
 
 ### Mooncake
 
-For `pullback`, preparation [builds the reverse rule](https://github.com/compintell/Mooncake.jl?tab=readme-ov-file#how-it-works) of the function.
+For `pullback`, preparation [builds the reverse rule](https://github.com/chalk-lab/Mooncake.jl?tab=readme-ov-file#how-it-works) of the function.
 
 ### Tracker
 
DifferentiationInterface/docs/src/faq/differentiability.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ Note that its rule writing is very different from ChainRulesCore.jl due to the p
8484

8585
### Mooncake
8686

87-
[Mooncake.jl](https://github.com/compintell/Mooncake.jl) is a recent package which also handles a large subset of all Julia programs out-of-the-box.
87+
[Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl) is a recent package which also handles a large subset of all Julia programs out-of-the-box.
8888

89-
Its [rule system](https://compintell.github.io/Mooncake.jl/dev/understanding_mooncake/rule_system/) is less expressive than that of Enzyme.jl, which might make it easier to start with.
89+
Its [rule system](https://chalk-lab.github.io/Mooncake.jl/dev/understanding_mooncake/rule_system/) is less expressive than that of Enzyme.jl, which might make it easier to start with.
9090

9191
## A rule mayhem?
9292

@@ -106,9 +106,9 @@ There are, however, translation utilities:
 
 - from ChainRulesCore.jl to ForwardDiff.jl with [ForwardDiffChainRules.jl](https://github.com/ThummeTo/ForwardDiffChainRules.jl)
 - from ChainRulesCore.jl to Enzyme.jl with [`Enzyme.@import_rrule`](https://enzymead.github.io/Enzyme.jl/stable/api/#Enzyme.@import_rrule-Tuple)
-- from ChainRulesCore.jl to Mooncake.jl with [`Mooncake.@from_rrule`](https://compintell.github.io/Mooncake.jl/dev/utilities/tools_for_rules/#Using-ChainRules.jl)
+- from ChainRulesCore.jl to Mooncake.jl with [`Mooncake.@from_rrule`](https://chalk-lab.github.io/Mooncake.jl/dev/utilities/tools_for_rules/#Using-ChainRules.jl)
 
 ### Backend switch
 
 Also note the existence of [`DifferentiationInterface.DifferentiateWith`](@ref), which allows the user to wrap a function that should be differentiated with a specific backend.
-Right now, it only targets ForwardDiff.jl, Mooncake.jl, ChainRules.jl-compatible backends (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)), but PRs are welcome to define Enzyme.jl rules for this object.
+Right now, it only targets [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl), [ChainRules.jl](https://juliadiff.org/ChainRulesCore.jl/stable/)-compatible backends (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)), but PRs are welcome to define Enzyme.jl rules for this object.
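
Note on the "Backend switch" paragraph in `differentiability.md`: a minimal sketch of what the wrapper is for, assuming ForwardDiff.jl and Zygote.jl are installed. The function `f` and the input `x` are hypothetical and not part of this commit. Here Zygote is the true backend and ForwardDiff is the substitute that actually differentiates `f`.

```julia
using DifferentiationInterface
import ForwardDiff, Zygote

# Hypothetical function: the in-place write is something Zygote
# cannot differentiate through on its own.
function f(x::AbstractVector)
    buf = zeros(eltype(x), 1)
    buf[1] = sum(abs2, x)  # mutation
    return buf[1]
end

# Request that `f` be differentiated with ForwardDiff,
# whatever backend the surrounding code uses.
f2 = DifferentiateWith(f, AutoForwardDiff())

x = [1.0, 2.0, 3.0]

# Zygote is the true backend; for `f2` it should pick up the
# ChainRules rule and hand the work over to ForwardDiff.
g = gradient(f2, AutoZygote(), x)
```

Calling `f2(x)` returns the same value as `f(x)`; only the differentiation path changes.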

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ using Mooncake:
     value_and_gradient!!,
     value_and_pullback!!,
     zero_tangent,
+    rdata_type,
     @is_primitive,
     zero_fcodual,
     MinimalCtx,

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Lines changed: 2 additions & 0 deletions
@@ -9,12 +9,14 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number
     # output is a vector, so we need to use the vector pullback
     function pullback_array!!(dy::NoRData)
         tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),))
+        @assert only(tx) isa rdata_type(typeof(x))
         return NoRData(), only(tx)
     end
 
     # output is a scalar, so we can use the scalar pullback
     function pullback_scalar!!(dy::Number)
         tx = DI.pullback(f, backend, primal_x, (dy,))
+        @assert only(tx) isa rdata_type(typeof(x))
         return NoRData(), only(tx)
     end
 

DifferentiationInterface/src/misc/differentiate_with.jl

Lines changed: 5 additions & 1 deletion
@@ -13,9 +13,13 @@ Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be
 
 !!! warning
     `DifferentiateWith` only supports out-of-place functions `y = f(x)` without additional context arguments.
-    It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl) or automatically importing rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules.
+    It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake](https://github.com/chalk-lab/Mooncake.jl) or automatically importing rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules.
     For any other true backend, the differentiation behavior is not altered by `DifferentiateWith` (it becomes a transparent wrapper).
 
+!!! warning
+    When using Mooncake as a substitute backend via `DifferentiateWith(f, AutoMooncake())`, the function `f` must not close over any active data.
+    As of now, we cannot differentiate with respect to parameters stored inside `f`.
+
 # Fields
 
 - `f`: the function in question, with signature `f(x)`
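
Note on the new Mooncake warning: one possible way to stay within the restriction is to pass everything that needs a derivative through the single argument `x` instead of capturing it. The sketch below assumes Mooncake.jl and Zygote.jl are installed. The names `w`, `f_closed`, `f_explicit` and the packing of parameters into one vector are hypothetical, not something this commit prescribes.

```julia
using DifferentiationInterface
import Mooncake, Zygote

w = [1.0, 2.0]

# Closure over `w`: if an outer computation needed derivatives with
# respect to `w`, they would not be propagated through the wrapper.
f_closed = x -> w[1] * x[1] + w[2] * x[2]

# Hypothetical alternative: pack parameters and inputs into one vector,
# so that every active value enters through the argument.
f_explicit(wx) = wx[1] * wx[3] + wx[2] * wx[4]

# Mooncake is the substitute backend, Zygote the true backend.
dw = DifferentiateWith(f_explicit, AutoMooncake(; config=nothing))
g = gradient(dw, AutoZygote(), vcat(w, [3.0, 4.0]))
```

With this layout, the first two entries of `g` are the derivatives with respect to the former closure parameters.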
