Skip to content

Commit 1c19d56

Browse files
committed
feat: support forward-mode Mooncake [experimental]
1 parent eeb4c86 commit 1c19d56

6 files changed

Lines changed: 227 additions & 9 deletions

File tree

DifferentiationInterface/docs/src/explanation/backends.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ We support the following dense backend choices from [ADTypes.jl](https://github.
1212
- [`AutoFiniteDifferences`](@extref ADTypes.AutoFiniteDifferences)
1313
- [`AutoForwardDiff`](@extref ADTypes.AutoForwardDiff)
1414
- [`AutoGTPSA`](@extref ADTypes.AutoGTPSA)
15-
- [`AutoMooncake`](@extref ADTypes.AutoMooncake)
15+
- [`AutoMooncake`](@extref ADTypes.AutoMooncake) and [`AutoMooncakeForward`](@extref ADTypes.AutoMooncake)
1616
- [`AutoPolyesterForwardDiff`](@extref ADTypes.AutoPolyesterForwardDiff)
1717
- [`AutoReverseDiff`](@extref ADTypes.AutoReverseDiff)
1818
- [`AutoSymbolics`](@extref ADTypes.AutoSymbolics)
@@ -48,6 +48,7 @@ In practice, many AD backends have custom implementations for high-level operato
4848
| `AutoForwardDiff` | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
4949
| `AutoGTPSA` | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ |
5050
| `AutoMooncake` | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
51+
| `AutoMooncakeForward` | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
5152
| `AutoPolyesterForwardDiff` | 🔀 | ❌ | 🔀 | ✅ | ✅ | 🔀 | 🔀 | 🔀 |
5253
| `AutoReverseDiff` | ❌ | 🔀 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
5354
| `AutoSymbolics` | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
@@ -68,6 +69,7 @@ Moreover, each context type is supported by a specific subset of backends:
6869
| `AutoForwardDiff` |||
6970
| `AutoGTPSA` |||
7071
| `AutoMooncake` |||
72+
| `AutoMooncakeForward` |||
7173
| `AutoPolyesterForwardDiff` |||
7274
| `AutoReverseDiff` |||
7375
| `AutoSymbolics` |||
Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,35 @@
11
module DifferentiationInterfaceMooncakeExt
22

3-
using ADTypes: ADTypes, AutoMooncake
3+
using ADTypes: ADTypes, AutoMooncake, AutoMooncakeForward
44
import DifferentiationInterface as DI
55
using Mooncake:
66
CoDual,
77
Config,
8+
Dual,
9+
prepare_derivative_cache,
810
prepare_gradient_cache,
911
prepare_pullback_cache,
12+
primal,
13+
tangent,
1014
tangent_type,
15+
value_and_derivative!!,
1116
value_and_gradient!!,
1217
value_and_pullback!!,
18+
zero_dual,
1319
zero_tangent,
1420
_copy_output,
1521
_copy_to_output!!
1622

17-
DI.check_available(::AutoMooncake) = true
23+
const AnyAutoMooncake{C} = Union{AutoMooncake{C},AutoMooncakeForward{C}}
1824

19-
get_config(::AutoMooncake{Nothing}) = Config()
20-
get_config(backend::AutoMooncake{<:Config}) = backend.config
25+
DI.check_available(::AnyAutoMooncake) = true
2126

22-
# tangents need to be copied before returning, otherwise they are still aliased in the cache
23-
mycopy(x::Union{Number,AbstractArray{<:Number}}) = copy(x)
24-
mycopy(x) = deepcopy(x)
27+
get_config(::AnyAutoMooncake{Nothing}) = Config()
28+
get_config(backend::AnyAutoMooncake{<:Config}) = backend.config
2529

2630
include("onearg.jl")
2731
include("twoarg.jl")
32+
include("forward_onearg.jl")
33+
include("forward_twoarg.jl")
2834

2935
end
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
## Pushforward
2+
3+
struct MooncakeOneArgPushforwardPrep{SIG,Tcache,DX} <: DI.PushforwardPrep{SIG}
4+
_sig::Val{SIG}
5+
cache::Tcache
6+
dx_righttype::DX
7+
end
8+
9+
function DI.prepare_pushforward_nokwarg(
10+
strict::Val,
11+
f::F,
12+
backend::AutoMooncakeForward,
13+
x,
14+
tx::NTuple,
15+
contexts::Vararg{DI.Context,C};
16+
) where {F,C}
17+
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
18+
config = get_config(backend)
19+
# TODO: silence_debug_messages
20+
cache = prepare_derivative_cache(f, x, map(DI.unwrap, contexts)...; config.debug_mode)
21+
dx_righttype = zero_tangent(x)
22+
prep = MooncakeOneArgPushforwardPrep(_sig, cache, dx_righttype)
23+
return prep
24+
end
25+
26+
function DI.value_and_pushforward(
27+
f::F,
28+
prep::MooncakeOneArgPushforwardPrep,
29+
backend::AutoMooncakeForward,
30+
x::X,
31+
tx::NTuple,
32+
contexts::Vararg{DI.Context,C};
33+
) where {F,C,X}
34+
DI.check_prep(f, prep, backend, x, tx, contexts...)
35+
ys_and_ty = map(tx) do dx
36+
dx_righttype =
37+
dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
38+
y_dual = value_and_derivative!!(
39+
prep.cache,
40+
zero_dual(f),
41+
Dual(x, dx_righttype),
42+
map(zero_dual DI.unwrap, contexts)...,
43+
)
44+
y = primal(y_dual)
45+
dy = _copy_output(tangent(y_dual))
46+
return y, dy
47+
end
48+
y = first(ys_and_ty[1])
49+
ty = last.(ys_and_ty)
50+
return y, ty
51+
end
52+
53+
function DI.pushforward(
54+
f::F,
55+
prep::MooncakeOneArgPushforwardPrep,
56+
backend::AutoMooncakeForward,
57+
x,
58+
tx::NTuple,
59+
contexts::Vararg{DI.Context,C};
60+
) where {F,C}
61+
DI.check_prep(f, prep, backend, x, tx, contexts...)
62+
return DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)[2]
63+
end
64+
65+
function DI.value_and_pushforward!(
66+
f::F,
67+
ty::NTuple,
68+
prep::MooncakeOneArgPushforwardPrep,
69+
backend::AutoMooncakeForward,
70+
x,
71+
tx::NTuple,
72+
contexts::Vararg{DI.Context,C};
73+
) where {F,C}
74+
DI.check_prep(f, prep, backend, x, tx, contexts...)
75+
y, new_ty = DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)
76+
foreach(copyto!, ty, new_ty)
77+
return y, ty
78+
end
79+
80+
function DI.pushforward!(
81+
f::F,
82+
ty::NTuple,
83+
prep::MooncakeOneArgPushforwardPrep,
84+
backend::AutoMooncakeForward,
85+
x,
86+
tx::NTuple,
87+
contexts::Vararg{DI.Context,C};
88+
) where {F,C}
89+
DI.check_prep(f, prep, backend, x, tx, contexts...)
90+
DI.value_and_pushforward!(f, ty, prep, backend, x, tx, contexts...)
91+
return ty
92+
end
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
## Pushforward
2+
3+
struct MooncakeTwoArgPushforwardPrep{SIG,Tcache,DX,DY} <: DI.PushforwardPrep{SIG}
4+
_sig::Val{SIG}
5+
cache::Tcache
6+
dx_righttype::DX
7+
dy_righttype::DY
8+
end
9+
10+
function DI.prepare_pushforward_nokwarg(
11+
strict::Val,
12+
f!::F,
13+
y,
14+
backend::AutoMooncakeForward,
15+
x,
16+
tx::NTuple,
17+
contexts::Vararg{DI.Context,C};
18+
) where {F,C}
19+
_sig = DI.signature(f!, y, backend, x, tx, contexts...; strict)
20+
config = get_config(backend)
21+
# TODO: silence_debug_messages
22+
cache = prepare_derivative_cache(
23+
f!, y, x, map(DI.unwrap, contexts)...; config.debug_mode
24+
)
25+
dx_righttype = zero_tangent(x)
26+
dy_righttype = zero_tangent(y)
27+
prep = MooncakeTwoArgPushforwardPrep(_sig, cache, dx_righttype, dy_righttype)
28+
return prep
29+
end
30+
31+
function DI.value_and_pushforward(
32+
f!::F,
33+
y,
34+
prep::MooncakeTwoArgPushforwardPrep,
35+
backend::AutoMooncakeForward,
36+
x::X,
37+
tx::NTuple,
38+
contexts::Vararg{DI.Context,C};
39+
) where {F,C,X}
40+
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
41+
ty = map(tx) do dx
42+
dx_righttype =
43+
dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
44+
y_dual = zero_dual(y)
45+
value_and_derivative!!(
46+
prep.cache,
47+
zero_dual(f!),
48+
y_dual,
49+
Dual(x, dx_righttype),
50+
map(zero_dual DI.unwrap, contexts)...,
51+
)
52+
dy = _copy_output(tangent(y_dual))
53+
return dy
54+
end
55+
return y, ty
56+
end
57+
58+
function DI.pushforward(
59+
f!::F,
60+
y,
61+
prep::MooncakeOneArgPushforwardPrep,
62+
backend::AutoMooncakeForward,
63+
x,
64+
tx::NTuple,
65+
contexts::Vararg{DI.Context,C};
66+
) where {F,C}
67+
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
68+
return DI.value_and_pushforward(f!, y, prep, backend, x, tx, contexts...)[2]
69+
end
70+
71+
function DI.value_and_pushforward!(
72+
f!::F,
73+
y::Y,
74+
ty::NTuple,
75+
prep::MooncakeOneArgPushforwardPrep,
76+
backend::AutoMooncakeForward,
77+
x::X,
78+
tx::NTuple,
79+
contexts::Vararg{DI.Context,C};
80+
) where {F,C,X,Y}
81+
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
82+
foreach(tx, ty) do dx, dy
83+
dx_righttype =
84+
dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
85+
dy_righttype =
86+
dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
87+
value_and_derivative!!(
88+
prep.cache,
89+
zero_dual(f!),
90+
Dual(y, dy_righttype),
91+
Dual(x, dx_righttype),
92+
map(zero_dual DI.unwrap, contexts)...,
93+
)
94+
dy === dy_righttype || copyto!(dy, dy_righttype)
95+
end
96+
return y, ty
97+
end
98+
99+
function DI.pushforward!(
100+
f!::F,
101+
y,
102+
ty::NTuple,
103+
prep::MooncakeOneArgPushforwardPrep,
104+
backend::AutoMooncakeForward,
105+
x,
106+
tx::NTuple,
107+
contexts::Vararg{DI.Context,C};
108+
) where {F,C}
109+
DI.check_prep(f!, y, ty, prep, backend, x, tx, contexts...)
110+
DI.pushforward!(f!, y, ty, prep, backend, x, tx, contexts...)
111+
return ty
112+
end

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ using ADTypes:
2828
AutoForwardDiff,
2929
AutoGTPSA,
3030
AutoMooncake,
31+
AutoMooncakeForward,
3132
AutoPolyesterForwardDiff,
3233
AutoReverseDiff,
3334
AutoSymbolics,
@@ -115,6 +116,7 @@ export AutoFiniteDifferences
115116
export AutoForwardDiff
116117
export AutoGTPSA
117118
export AutoMooncake
119+
export AutoMooncakeForward
118120
export AutoPolyesterForwardDiff
119121
export AutoReverseDiff
120122
export AutoSymbolics

DifferentiationInterface/test/Back/Mooncake/test.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@ check_no_implicit_imports(DifferentiationInterface)
1010

1111
LOGGING = get(ENV, "CI", "false") == "false"
1212

13-
backends = [AutoMooncake(; config=nothing), AutoMooncake(; config=Mooncake.Config())]
13+
backends = [
14+
AutoMooncake(; config=nothing),
15+
AutoMooncake(; config=Mooncake.Config()),
16+
AutoMooncakeForward(; config=nothing);
17+
]
1418

1519
for backend in backends
1620
@test check_available(backend)

0 commit comments

Comments
 (0)