Skip to content

Commit 2f06fb1

Browse files
authored
Implement contexts for more backends (#492)
* Rename extras to prep * Typos * Remove map fix * Allow context for FiniteDifferences and Tracker, improve them for ForwardDiff and Zygote * Revert * Rename * Rewrap * With contexts * Fixes * Better Zygote hvp * Fixes * Fix
1 parent 932913a commit 2f06fb1

19 files changed

Lines changed: 659 additions & 271 deletions

File tree

DifferentiationInterface/docs/src/explanation/advanced.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Another option would be creating a closure, but that is sometimes undesirable.
1212

1313
!!! warning
1414
This feature is still experimental, and will likely not be supported by all backends.
15-
At the moment, it only works with ForwardDiff, Zygote and Enzyme.
15+
At the moment, it only works with certain backends, among which ForwardDiff, Zygote and Enzyme.
1616

1717
### Types of contexts
1818

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,13 @@ using ChainRulesCore:
1212
using Compat
1313
import DifferentiationInterface as DI
1414
using DifferentiationInterface:
15-
DifferentiateWith, NoPullbackPrep, NoPushforwardPrep, PullbackPrep, Tangents
15+
Constant,
16+
DifferentiateWith,
17+
NoPullbackPrep,
18+
NoPushforwardPrep,
19+
PullbackPrep,
20+
Tangents,
21+
unwrap
1622

1723
ruleconfig(backend::AutoChainRules) = backend.ruleconfig
1824

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ struct ChainRulesPullbackPrepSamePoint{Y,PB} <: PullbackPrep
55
pb::PB
66
end
77

8-
DI.prepare_pullback(f, ::AutoReverseChainRules, x, ty::Tangents) = NoPullbackPrep()
8+
function DI.prepare_pullback(f, ::AutoReverseChainRules, x, ty::Tangents)
9+
return NoPullbackPrep()
10+
end
911

1012
function DI.prepare_pullback_same_point(
1113
f, ::NoPullbackPrep, backend::AutoReverseChainRules, x, ty::Tangents
@@ -21,7 +23,7 @@ function DI.value_and_pullback(
2123
rc = ruleconfig(backend)
2224
y, pb = rrule_via_ad(rc, f, x)
2325
tx = map(ty) do dy
24-
last(pb(dy))
26+
pb(dy)[2]
2527
end
2628
return y, tx
2729
end
@@ -31,7 +33,7 @@ function DI.value_and_pullback(
3133
)
3234
@compat (; y, pb) = prep
3335
tx = map(ty) do dy
34-
last(pb(dy))
36+
pb(dy)[2]
3537
end
3638
return copy(y), tx
3739
end
@@ -41,7 +43,7 @@ function DI.pullback(
4143
)
4244
@compat (; pb) = prep
4345
tx = map(ty) do dy
44-
last(pb(dy))
46+
pb(dy)[2]
4547
end
4648
return tx
4749
end

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl

Lines changed: 114 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,14 @@ module DifferentiationInterfaceFiniteDifferencesExt
33
using ADTypes: AutoFiniteDifferences
44
import DifferentiationInterface as DI
55
using DifferentiationInterface:
6-
NoGradientPrep, NoJacobianPrep, NoPullbackPrep, NoPushforwardPrep, Tangents
6+
Context,
7+
NoGradientPrep,
8+
NoJacobianPrep,
9+
NoPullbackPrep,
10+
NoPushforwardPrep,
11+
Tangents,
12+
unwrap,
13+
with_contexts
714
using FiniteDifferences: FiniteDifferences, grad, jacobian, jvp, j′vp
815
using LinearAlgebra: dot
916

@@ -12,85 +19,158 @@ DI.inplace_support(::AutoFiniteDifferences) = DI.InPlaceNotSupported()
1219

1320
## Pushforward
1421

15-
function DI.prepare_pushforward(f, ::AutoFiniteDifferences, x, tx::Tangents)
22+
function DI.prepare_pushforward(
23+
f, ::AutoFiniteDifferences, x, tx::Tangents, contexts::Vararg{Context,C}
24+
) where {C}
1625
return NoPushforwardPrep()
1726
end
1827

1928
function DI.pushforward(
20-
f, ::NoPushforwardPrep, backend::AutoFiniteDifferences, x, tx::Tangents
21-
)
29+
f,
30+
::NoPushforwardPrep,
31+
backend::AutoFiniteDifferences,
32+
x,
33+
tx::Tangents,
34+
contexts::Vararg{Context,C},
35+
) where {C}
36+
fc = with_contexts(f, contexts...)
2237
ty = map(tx) do dx
23-
jvp(backend.fdm, f, (x, dx))
38+
jvp(backend.fdm, fc, (x, dx))
2439
end
2540
return ty
2641
end
2742

2843
function DI.value_and_pushforward(
29-
f, prep::NoPushforwardPrep, backend::AutoFiniteDifferences, x, tx::Tangents
30-
)
31-
return f(x), DI.pushforward(f, prep, backend, x, tx)
44+
f,
45+
prep::NoPushforwardPrep,
46+
backend::AutoFiniteDifferences,
47+
x,
48+
tx::Tangents,
49+
contexts::Vararg{Context,C},
50+
) where {C}
51+
return f(x, map(unwrap, contexts)...),
52+
DI.pushforward(f, prep, backend, x, tx, contexts...)
3253
end
3354

3455
## Pullback
3556

36-
DI.prepare_pullback(f, ::AutoFiniteDifferences, x, ty::Tangents) = NoPullbackPrep()
57+
function DI.prepare_pullback(
58+
f, ::AutoFiniteDifferences, x, ty::Tangents, contexts::Vararg{Context,C}
59+
) where {C}
60+
return NoPullbackPrep()
61+
end
3762

38-
function DI.pullback(f, ::NoPullbackPrep, backend::AutoFiniteDifferences, x, ty::Tangents)
63+
function DI.pullback(
64+
f,
65+
::NoPullbackPrep,
66+
backend::AutoFiniteDifferences,
67+
x,
68+
ty::Tangents,
69+
contexts::Vararg{Context,C},
70+
) where {C}
71+
fc = with_contexts(f, contexts...)
3972
tx = map(ty) do dy
40-
only(j′vp(backend.fdm, f, dy, x))
73+
only(j′vp(backend.fdm, fc, dy, x))
4174
end
4275
return tx
4376
end
4477

4578
function DI.value_and_pullback(
46-
f, prep::NoPullbackPrep, backend::AutoFiniteDifferences, x, ty::Tangents
47-
)
48-
return f(x), DI.pullback(f, prep, backend, x, ty)
79+
f,
80+
prep::NoPullbackPrep,
81+
backend::AutoFiniteDifferences,
82+
x,
83+
ty::Tangents,
84+
contexts::Vararg{Context,C},
85+
) where {C}
86+
return f(x, map(unwrap, contexts)...), DI.pullback(f, prep, backend, x, ty, contexts...)
4987
end
5088

5189
## Gradient
5290

53-
DI.prepare_gradient(f, ::AutoFiniteDifferences, x) = NoGradientPrep()
91+
function DI.prepare_gradient(
92+
f, ::AutoFiniteDifferences, x, contexts::Vararg{Context,C}
93+
) where {C}
94+
return NoGradientPrep()
95+
end
5496

55-
function DI.gradient(f, ::NoGradientPrep, backend::AutoFiniteDifferences, x)
56-
return only(grad(backend.fdm, f, x))
97+
function DI.gradient(
98+
f, ::NoGradientPrep, backend::AutoFiniteDifferences, x, contexts::Vararg{Context,C}
99+
) where {C}
100+
fc = with_contexts(f, contexts...)
101+
return only(grad(backend.fdm, fc, x))
57102
end
58103

59-
function DI.value_and_gradient(f, prep::NoGradientPrep, backend::AutoFiniteDifferences, x)
60-
return f(x), DI.gradient(f, prep, backend, x)
104+
function DI.value_and_gradient(
105+
f, prep::NoGradientPrep, backend::AutoFiniteDifferences, x, contexts::Vararg{Context,C}
106+
) where {C}
107+
return f(x, map(unwrap, contexts)...), DI.gradient(f, prep, backend, x, contexts...)
61108
end
62109

63-
function DI.gradient!(f, grad, prep::NoGradientPrep, backend::AutoFiniteDifferences, x)
64-
return copyto!(grad, DI.gradient(f, prep, backend, x))
110+
function DI.gradient!(
111+
f,
112+
grad,
113+
prep::NoGradientPrep,
114+
backend::AutoFiniteDifferences,
115+
x,
116+
contexts::Vararg{Context,C},
117+
) where {C}
118+
return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...))
65119
end
66120

67121
function DI.value_and_gradient!(
68-
f, grad, prep::NoGradientPrep, backend::AutoFiniteDifferences, x
69-
)
70-
y, new_grad = DI.value_and_gradient(f, prep, backend, x)
122+
f,
123+
grad,
124+
prep::NoGradientPrep,
125+
backend::AutoFiniteDifferences,
126+
x,
127+
contexts::Vararg{Context,C},
128+
) where {C}
129+
y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...)
71130
return y, copyto!(grad, new_grad)
72131
end
73132

74133
## Jacobian
75134

76-
DI.prepare_jacobian(f, ::AutoFiniteDifferences, x) = NoJacobianPrep()
135+
function DI.prepare_jacobian(
136+
f, ::AutoFiniteDifferences, x, contexts::Vararg{Context,C}
137+
) where {C}
138+
return NoJacobianPrep()
139+
end
77140

78-
function DI.jacobian(f, ::NoJacobianPrep, backend::AutoFiniteDifferences, x)
79-
return only(jacobian(backend.fdm, f, x))
141+
function DI.jacobian(
142+
f, ::NoJacobianPrep, backend::AutoFiniteDifferences, x, contexts::Vararg{Context,C}
143+
) where {C}
144+
fc = with_contexts(f, contexts...)
145+
return only(jacobian(backend.fdm, fc, x))
80146
end
81147

82-
function DI.value_and_jacobian(f, prep::NoJacobianPrep, backend::AutoFiniteDifferences, x)
83-
return f(x), DI.jacobian(f, prep, backend, x)
148+
function DI.value_and_jacobian(
149+
f, prep::NoJacobianPrep, backend::AutoFiniteDifferences, x, contexts::Vararg{Context,C}
150+
) where {C}
151+
return f(x, map(unwrap, contexts)...), DI.jacobian(f, prep, backend, x, contexts...)
84152
end
85153

86-
function DI.jacobian!(f, jac, prep::NoJacobianPrep, backend::AutoFiniteDifferences, x)
87-
return copyto!(jac, DI.jacobian(f, prep, backend, x))
154+
function DI.jacobian!(
155+
f,
156+
jac,
157+
prep::NoJacobianPrep,
158+
backend::AutoFiniteDifferences,
159+
x,
160+
contexts::Vararg{Context,C},
161+
) where {C}
162+
return copyto!(jac, DI.jacobian(f, prep, backend, x, contexts...))
88163
end
89164

90165
function DI.value_and_jacobian!(
91-
f, jac, prep::NoJacobianPrep, backend::AutoFiniteDifferences, x
92-
)
93-
y, new_jac = DI.value_and_jacobian(f, prep, backend, x)
166+
f,
167+
jac,
168+
prep::NoJacobianPrep,
169+
backend::AutoFiniteDifferences,
170+
x,
171+
contexts::Vararg{Context,C},
172+
) where {C}
173+
y, new_jac = DI.value_and_jacobian(f, prep, backend, x, contexts...)
94174
return y, copyto!(jac, new_jac)
95175
end
96176

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@ using DifferentiationInterface:
1515
NoDerivativePrep,
1616
NoSecondDerivativePrep,
1717
PushforwardPrep,
18+
Rewrap,
1819
SecondOrder,
1920
Tangents,
2021
inner,
2122
outer,
22-
unwrap
23+
unwrap,
24+
with_contexts
2325
using ForwardDiff.DiffResults:
2426
DiffResults, DiffResult, GradientResult, HessianResult, MutableDiffResult
2527
using ForwardDiff:

0 commit comments

Comments
 (0)