Skip to content

Commit 1307890

Browse files
gdallewilltebbutt
andauthored
[BREAKING] Transition from Tapir to Mooncake (#500)
* [BREAKING] Transition from Tapir to Mooncake * Multi-arg * Reactivate CI * Apply suggestions from code review Co-authored-by: Will Tebbutt <wt0881@my.bristol.ac.uk> --------- Co-authored-by: Will Tebbutt <wt0881@my.bristol.ac.uk>
1 parent 987ca87 commit 1307890

13 files changed

Lines changed: 302 additions & 191 deletions

File tree

.github/workflows/Test.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@ jobs:
4040
- Back/FiniteDiff
4141
- Back/FiniteDifferences
4242
- Back/ForwardDiff
43+
- Back/Mooncake
4344
- Back/PolyesterForwardDiff
4445
- Back/ReverseDiff
4546
- Back/SecondOrder
4647
- Back/Symbolics
47-
- Back/Tapir
4848
- Back/Tracker
4949
- Back/Zygote
5050
- Misc/DifferentiateWith
@@ -67,14 +67,14 @@ jobs:
6767
group: Back/FiniteDiff
6868
- version: "lts"
6969
group: Back/FastDifferentiation
70+
- version: "lts"
71+
group: Back/Mooncake
7072
- version: "lts"
7173
group: Back/PolyesterForwardDiff
7274
- version: "lts"
7375
group: Back/SecondOrder
7476
- version: "lts"
7577
group: Back/Symbolics
76-
- version: "lts"
77-
group: Back/Tapir
7878
- version: "lts"
7979
group: Misc/SparsityDetector
8080
- version: "lts"
@@ -89,7 +89,7 @@ jobs:
8989
- version: "pre"
9090
group: Back/Enzyme
9191
- version: "pre"
92-
group: Back/Tapir
92+
group: Back/Mooncake
9393
- version: "pre"
9494
group: Back/SecondOrder
9595
- version: "pre"

DifferentiationInterface/Project.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@ FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
1717
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1818
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
1919
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
20+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2021
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
2122
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2223
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2324
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
2425
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
25-
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
2626
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2727
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2828

@@ -34,17 +34,17 @@ DifferentiationInterfaceFastDifferentiationExt = "FastDifferentiation"
3434
DifferentiationInterfaceFiniteDiffExt = "FiniteDiff"
3535
DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
3636
DifferentiationInterfaceForwardDiffExt = "ForwardDiff"
37+
DifferentiationInterfaceMooncakeExt = "Mooncake"
3738
DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff"
3839
DifferentiationInterfaceReverseDiffExt = "ReverseDiff"
3940
DifferentiationInterfaceSparseArraysExt = "SparseArrays"
4041
DifferentiationInterfaceSparseMatrixColoringsExt = "SparseMatrixColorings"
4142
DifferentiationInterfaceSymbolicsExt = "Symbolics"
42-
DifferentiationInterfaceTapirExt = "Tapir"
4343
DifferentiationInterfaceTrackerExt = "Tracker"
4444
DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"]
4545

4646
[compat]
47-
ADTypes = "1.7.0"
47+
ADTypes = "1.9.0"
4848
ChainRulesCore = "1.23.0"
4949
Compat = "3.46,4.2"
5050
Diffractor = "=0.2.6"
@@ -54,14 +54,14 @@ FiniteDiff = "2.23.1"
5454
FiniteDifferences = "0.12.31"
5555
ForwardDiff = "0.10.36"
5656
LinearAlgebra = "<0.0.1,1"
57+
Mooncake = "0.4.0"
5758
PackageExtensionCompat = "1.0.2"
5859
PolyesterForwardDiff = "0.1.1"
5960
ReverseDiff = "1.15.1"
6061
SparseArrays = "<0.0.1,1"
6162
SparseConnectivityTracer = "0.5.0,0.6"
6263
SparseMatrixColorings = "0.4.0"
6364
Symbolics = "5.27.1, 6"
64-
Tapir = "0.2.48"
6565
Tracker = "0.2.33"
6666
Zygote = "0.6.69"
6767
julia = "1.6"
@@ -81,6 +81,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
8181
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
8282
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
8383
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
84+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
8485
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
8586
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
8687
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -91,7 +92,6 @@ SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
9192
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
9293
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
9394
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
94-
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
9595
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
9696
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
9797
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

DifferentiationInterface/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ We support the following backends defined by [ADTypes.jl](https://github.com/Sci
3737
- [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl)
3838
- [FiniteDifferences.jl](https://github.com/JuliaDiff/FiniteDifferences.jl)
3939
- [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl)
40+
- [Mooncake.jl](https://github.com/compintell/Mooncake.jl)
4041
- [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl)
4142
- [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl)
4243
- [Symbolics.jl](https://github.com/JuliaSymbolics/Symbolics.jl)
43-
- [Tapir.jl](https://github.com/withbayes/Tapir.jl)
4444
- [Tracker.jl](https://github.com/FluxML/Tracker.jl)
4545
- [Zygote.jl](https://github.com/FluxML/Zygote.jl)
4646

DifferentiationInterface/docs/src/explanation/backends.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ We support the following dense backend choices from [ADTypes.jl](https://github.
1111
- [`AutoFiniteDiff`](@extref ADTypes.AutoFiniteDiff)
1212
- [`AutoFiniteDifferences`](@extref ADTypes.AutoFiniteDifferences)
1313
- [`AutoForwardDiff`](@extref ADTypes.AutoForwardDiff)
14+
- [`AutoMooncake`](@extref ADTypes.AutoMooncake)
1415
- [`AutoPolyesterForwardDiff`](@extref ADTypes.AutoPolyesterForwardDiff)
1516
- [`AutoReverseDiff`](@extref ADTypes.AutoReverseDiff)
1617
- [`AutoSymbolics`](@extref ADTypes.AutoSymbolics)
17-
- [`AutoTapir`](@extref ADTypes.AutoTapir)
1818
- [`AutoTracker`](@extref ADTypes.AutoTracker)
1919
- [`AutoZygote`](@extref ADTypes.AutoZygote)
2020

@@ -55,10 +55,10 @@ In practice, many AD backends have custom implementations for high-level operato
5555
| `AutoFiniteDiff` | 🔀 | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
5656
| `AutoFiniteDifferences` | 🔀 | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
5757
| `AutoForwardDiff` | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
58+
| `AutoMooncake` | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
5859
| `AutoPolyesterForwardDiff` | 🔀 | ❌ | 🔀 | ✅ | ✅ | 🔀 | 🔀 | 🔀 |
5960
| `AutoReverseDiff` | ❌ | 🔀 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
6061
| `AutoSymbolics` | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
61-
| `AutoTapir` | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
6262
| `AutoTracker` | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
6363
| `AutoZygote` | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | 🔀 | ❌ |
6464

@@ -144,9 +144,9 @@ For all operators, preparation generates an [executable function](https://docs.s
144144
!!! warning
145145
Preparation can be very slow for symbolic AD.
146146

147-
### Tapir
147+
### Mooncake
148148

149-
For `pullback`, preparation [builds the reverse rule](https://github.com/withbayes/Tapir.jl?tab=readme-ov-file#how-it-works) of the function.
149+
For `pullback`, preparation [builds the reverse rule](https://github.com/compintell/Mooncake.jl?tab=readme-ov-file#how-it-works) of the function.
150150

151151
### Tracker
152152

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
module DifferentiationInterfaceMooncakeExt
2+
3+
using ADTypes: ADTypes, AutoMooncake
4+
import DifferentiationInterface as DI
5+
using DifferentiationInterface: Context, PullbackPrep, unwrap
6+
using Mooncake:
7+
CoDual,
8+
Config,
9+
NoRData,
10+
NoTangent,
11+
build_rrule,
12+
fdata,
13+
get_interpreter,
14+
increment!!,
15+
primal,
16+
rdata,
17+
set_to_zero!!,
18+
tangent,
19+
tangent_type,
20+
value_and_pullback!!,
21+
zero_codual,
22+
zero_fcodual,
23+
zero_tangent,
24+
__value_and_pullback!!
25+
26+
DI.check_available(::AutoMooncake) = true
27+
28+
get_config(::AutoMooncake{Nothing}) = Config()
29+
get_config(backend::AutoMooncake{<:Config}) = backend.config
30+
31+
include("onearg.jl")
32+
include("twoarg.jl")
33+
34+
end
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
struct MooncakeOneArgPullbackPrep{Y,R} <: PullbackPrep
2+
y_prototype::Y
3+
rrule::R
4+
end
5+
6+
function DI.prepare_pullback(
7+
f, backend::AutoMooncake, x, ty::DI.Tangents, contexts::Vararg{Context,C}
8+
) where {C}
9+
y = f(x, map(unwrap, contexts)...)
10+
config = get_config(backend)
11+
rrule = build_rrule(
12+
get_interpreter(),
13+
Tuple{typeof(f),typeof(x),typeof.(map(unwrap, contexts))...};
14+
debug_mode=config.debug_mode,
15+
silence_debug_messages=config.silence_debug_messages,
16+
)
17+
prep = MooncakeOneArgPullbackPrep(y, rrule)
18+
DI.value_and_pullback(f, prep, backend, x, ty, contexts...) # warm up
19+
return prep
20+
end
21+
22+
function DI.value_and_pullback(
23+
f,
24+
prep::MooncakeOneArgPullbackPrep{Y},
25+
::AutoMooncake,
26+
x,
27+
ty::DI.Tangents{1},
28+
contexts::Vararg{Context,C},
29+
) where {Y,C}
30+
dy = only(ty)
31+
dy_righttype = convert(tangent_type(Y), dy)
32+
new_y, (_, new_dx) = value_and_pullback!!(
33+
prep.rrule, dy_righttype, f, x, map(unwrap, contexts)...
34+
)
35+
return new_y, DI.Tangents(new_dx)
36+
end
37+
38+
function DI.value_and_pullback!(
39+
f,
40+
prep::MooncakeOneArgPullbackPrep{Y},
41+
tx::DI.Tangents,
42+
::AutoMooncake,
43+
x,
44+
ty::DI.Tangents{1},
45+
contexts::Vararg{Context,C},
46+
) where {Y,C}
47+
dx, dy = only(tx), only(ty)
48+
dy_righttype = convert(tangent_type(Y), dy)
49+
dx_righttype = set_to_zero!!(convert(tangent_type(typeof(x)), dx))
50+
contexts_coduals = map(zero_fcodual unwrap, contexts)
51+
y, (_, new_dx) = __value_and_pullback!!(
52+
prep.rrule,
53+
dy_righttype,
54+
zero_codual(f),
55+
CoDual(x, dx_righttype),
56+
contexts_coduals...,
57+
)
58+
copyto!(dx, new_dx)
59+
return y, tx
60+
end
61+
62+
function DI.value_and_pullback(
63+
f,
64+
prep::MooncakeOneArgPullbackPrep,
65+
backend::AutoMooncake,
66+
x,
67+
ty::DI.Tangents,
68+
contexts::Vararg{Context,C},
69+
) where {C}
70+
ys_and_dxs = map(ty.d) do dy
71+
y, tx = DI.value_and_pullback(f, prep, backend, x, DI.Tangents(dy), contexts...)
72+
y, only(tx)
73+
end
74+
y = first(ys_and_dxs[1])
75+
dxs = last.(ys_and_dxs)
76+
return y, DI.Tangents(dxs...)
77+
end
78+
79+
function DI.pullback(
80+
f,
81+
prep::MooncakeOneArgPullbackPrep,
82+
backend::AutoMooncake,
83+
x,
84+
ty::DI.Tangents,
85+
contexts::Vararg{Context,C},
86+
) where {C}
87+
return DI.value_and_pullback(f, prep, backend, x, ty, contexts...)[2]
88+
end
89+
90+
function DI.pullback!(
91+
f,
92+
tx::DI.Tangents,
93+
prep::MooncakeOneArgPullbackPrep,
94+
backend::AutoMooncake,
95+
x,
96+
ty::DI.Tangents,
97+
contexts::Vararg{Context,C},
98+
) where {C}
99+
return DI.value_and_pullback!(f, tx, prep, backend, x, ty, contexts...)[2]
100+
end

0 commit comments

Comments
 (0)