Skip to content

Commit d53f8ac

Browse files
authored
chore: get rid of implicit imports and clarify extension imports (#649)
* Remove imports from DI in extensions * Add DI prefix everywhere * Unwrap * Typos * Typos * Context * Inner outer * Typos * Basis * Explicit imports from DI in DIT * Typos * Gradient and hvp * Typos * Typo * Typos * Typos * Typos * Remove implicit imports in DIT, add tests * Relu * Typo * DIT * Retoggle test on 1.11 * Not broken * Public tests on 1.11 * Bump
1 parent 8fe1dd1 commit d53f8ac

68 files changed

Lines changed: 1116 additions & 1011 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

DifferentiationInterface/Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99

1010
[weakdeps]
1111
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
12+
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
1213
Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
1314
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
1415
FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
@@ -32,10 +33,10 @@ DifferentiationInterfaceEnzymeExt = "Enzyme"
3233
DifferentiationInterfaceFastDifferentiationExt = "FastDifferentiation"
3334
DifferentiationInterfaceFiniteDiffExt = "FiniteDiff"
3435
DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
35-
DifferentiationInterfaceForwardDiffExt = "ForwardDiff"
36+
DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"]
3637
DifferentiationInterfaceMooncakeExt = "Mooncake"
3738
DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff"
38-
DifferentiationInterfaceReverseDiffExt = "ReverseDiff"
39+
DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"]
3940
DifferentiationInterfaceSparseArraysExt = "SparseArrays"
4041
DifferentiationInterfaceSparseMatrixColoringsExt = "SparseMatrixColorings"
4142
DifferentiationInterfaceStaticArraysExt = "StaticArrays"

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@ using ChainRulesCore:
1010
frule_via_ad,
1111
rrule_via_ad
1212
import DifferentiationInterface as DI
13-
using DifferentiationInterface:
14-
Constant, DifferentiateWith, NoPullbackPrep, NoPushforwardPrep, PullbackPrep, unwrap
1513

1614
ruleconfig(backend::AutoChainRules) = backend.ruleconfig
1715

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function ChainRulesCore.rrule(dw::DifferentiateWith, x)
1+
function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x)
22
(; f, backend) = dw
33
y = f(x)
44
prep_same = DI.prepare_pullback_same_point(f, backend, x, (y,))

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,39 @@
11
## Pullback
22

3-
struct ChainRulesPullbackPrepSamePoint{Y,PB} <: PullbackPrep
3+
struct ChainRulesPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep
44
y::Y
55
pb::PB
66
end
77

88
function DI.prepare_pullback(
9-
f, ::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{Constant,C}
9+
f, ::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.Constant,C}
1010
) where {C}
11-
return NoPullbackPrep()
11+
return DI.NoPullbackPrep()
1212
end
1313

1414
function DI.prepare_pullback_same_point(
1515
f,
16-
::NoPullbackPrep,
16+
::DI.NoPullbackPrep,
1717
backend::AutoReverseChainRules,
1818
x,
1919
ty::NTuple,
20-
contexts::Vararg{Constant,C},
20+
contexts::Vararg{DI.Constant,C},
2121
) where {C}
2222
rc = ruleconfig(backend)
23-
y, pb = rrule_via_ad(rc, f, x, map(unwrap, contexts)...)
23+
y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
2424
return ChainRulesPullbackPrepSamePoint(y, pb)
2525
end
2626

2727
function DI.value_and_pullback(
2828
f,
29-
::NoPullbackPrep,
29+
::DI.NoPullbackPrep,
3030
backend::AutoReverseChainRules,
3131
x,
3232
ty::NTuple,
33-
contexts::Vararg{Constant,C},
33+
contexts::Vararg{DI.Constant,C},
3434
) where {C}
3535
rc = ruleconfig(backend)
36-
y, pb = rrule_via_ad(rc, f, x, map(unwrap, contexts)...)
36+
y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
3737
tx = map(ty) do dy
3838
pb(dy)[2]
3939
end
@@ -46,7 +46,7 @@ function DI.value_and_pullback(
4646
::AutoReverseChainRules,
4747
x,
4848
ty::NTuple,
49-
contexts::Vararg{Constant,C},
49+
contexts::Vararg{DI.Constant,C},
5050
) where {C}
5151
(; y, pb) = prep
5252
tx = map(ty) do dy
@@ -61,7 +61,7 @@ function DI.pullback(
6161
::AutoReverseChainRules,
6262
x,
6363
ty::NTuple,
64-
contexts::Vararg{Constant,C},
64+
contexts::Vararg{DI.Constant,C},
6565
) where {C}
6666
(; pb) = prep
6767
tx = map(ty) do dy

DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ module DifferentiationInterfaceDiffractorExt
22

33
using ADTypes: ADTypes, AutoDiffractor
44
import DifferentiationInterface as DI
5-
using DifferentiationInterface: NoPushforwardPrep
65
using Diffractor: DiffractorRuleConfig, TaylorTangentIndex, ZeroBundle, bundle, ∂☆
76

87
DI.check_available(::AutoDiffractor) = true
@@ -11,9 +10,9 @@ DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow()
1110

1211
## Pushforward
1312

14-
DI.prepare_pushforward(f, ::AutoDiffractor, x, tx::NTuple) = NoPushforwardPrep()
13+
DI.prepare_pushforward(f, ::AutoDiffractor, x, tx::NTuple) = DI.NoPushforwardPrep()
1514

16-
function DI.pushforward(f, ::NoPushforwardPrep, ::AutoDiffractor, x, tx::NTuple)
15+
function DI.pushforward(f, ::DI.NoPushforwardPrep, ::AutoDiffractor, x, tx::NTuple)
1716
ty = map(tx) do dx
1817
# code copied from Diffractor.jl
1918
z = ∂☆{1}()(ZeroBundle{1}(f), bundle(x, dx))
@@ -24,7 +23,7 @@ function DI.pushforward(f, ::NoPushforwardPrep, ::AutoDiffractor, x, tx::NTuple)
2423
end
2524

2625
function DI.value_and_pushforward(
27-
f, prep::NoPushforwardPrep, backend::AutoDiffractor, x, tx::NTuple
26+
f, prep::DI.NoPushforwardPrep, backend::AutoDiffractor, x, tx::NTuple
2827
)
2928
return f(x), DI.pushforward(f, prep, backend, x, tx)
3029
end

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,6 @@ module DifferentiationInterfaceEnzymeExt
33
using ADTypes: ADTypes, AutoEnzyme
44
using Base: Fix1
55
import DifferentiationInterface as DI
6-
using DifferentiationInterface:
7-
Context,
8-
DerivativePrep,
9-
GradientPrep,
10-
JacobianPrep,
11-
HVPPrep,
12-
PullbackPrep,
13-
PushforwardPrep,
14-
NoDerivativePrep,
15-
NoGradientPrep,
16-
NoHVPPrep,
17-
NoJacobianPrep,
18-
NoPullbackPrep,
19-
NoPushforwardPrep
206
using Enzyme:
217
Active,
228
Annotation,

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,18 @@ function DI.prepare_pushforward(
55
::AutoEnzyme{<:Union{ForwardMode,Nothing}},
66
x,
77
tx::NTuple,
8-
contexts::Vararg{Context,C},
8+
contexts::Vararg{DI.Context,C},
99
) where {F,C}
10-
return NoPushforwardPrep()
10+
return DI.NoPushforwardPrep()
1111
end
1212

1313
function DI.value_and_pushforward(
1414
f::F,
15-
::NoPushforwardPrep,
15+
::DI.NoPushforwardPrep,
1616
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
1717
x,
1818
tx::NTuple{1},
19-
contexts::Vararg{Context,C},
19+
contexts::Vararg{DI.Context,C},
2020
) where {F,C}
2121
f_and_df = get_f_and_df(f, backend)
2222
dx_sametype = convert(typeof(x), only(tx))
@@ -29,11 +29,11 @@ end
2929

3030
function DI.value_and_pushforward(
3131
f::F,
32-
::NoPushforwardPrep,
32+
::DI.NoPushforwardPrep,
3333
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
3434
x,
3535
tx::NTuple{B},
36-
contexts::Vararg{Context,C},
36+
contexts::Vararg{DI.Context,C},
3737
) where {F,B,C}
3838
f_and_df = get_f_and_df(f, backend, Val(B))
3939
tx_sametype = map(Fix1(convert, typeof(x)), tx)
@@ -46,11 +46,11 @@ end
4646

4747
function DI.pushforward(
4848
f::F,
49-
::NoPushforwardPrep,
49+
::DI.NoPushforwardPrep,
5050
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
5151
x,
5252
tx::NTuple{1},
53-
contexts::Vararg{Context,C},
53+
contexts::Vararg{DI.Context,C},
5454
) where {F,C}
5555
f_and_df = get_f_and_df(f, backend)
5656
dx_sametype = convert(typeof(x), only(tx))
@@ -63,11 +63,11 @@ end
6363

6464
function DI.pushforward(
6565
f::F,
66-
::NoPushforwardPrep,
66+
::DI.NoPushforwardPrep,
6767
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
6868
x,
6969
tx::NTuple{B},
70-
contexts::Vararg{Context,C},
70+
contexts::Vararg{DI.Context,C},
7171
) where {F,B,C}
7272
f_and_df = get_f_and_df(f, backend, Val(B))
7373
tx_sametype = map(Fix1(convert, typeof(x)), tx)
@@ -81,11 +81,11 @@ end
8181
function DI.value_and_pushforward!(
8282
f::F,
8383
ty::NTuple,
84-
prep::NoPushforwardPrep,
84+
prep::DI.NoPushforwardPrep,
8585
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
8686
x,
8787
tx::NTuple,
88-
contexts::Vararg{Context,C},
88+
contexts::Vararg{DI.Context,C},
8989
) where {F,C}
9090
# dy cannot be passed anyway
9191
y, new_ty = DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)
@@ -96,11 +96,11 @@ end
9696
function DI.pushforward!(
9797
f::F,
9898
ty::NTuple,
99-
prep::NoPushforwardPrep,
99+
prep::DI.NoPushforwardPrep,
100100
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
101101
x,
102102
tx::NTuple,
103-
contexts::Vararg{Context,C},
103+
contexts::Vararg{DI.Context,C},
104104
) where {F,C}
105105
# dy cannot be passed anyway
106106
new_ty = DI.pushforward(f, prep, backend, x, tx, contexts...)
@@ -110,7 +110,7 @@ end
110110

111111
## Gradient
112112

113-
struct EnzymeForwardGradientPrep{B,O} <: GradientPrep
113+
struct EnzymeForwardGradientPrep{B,O} <: DI.GradientPrep
114114
shadows::O
115115
end
116116

@@ -175,7 +175,7 @@ end
175175

176176
## Jacobian
177177

178-
struct EnzymeForwardOneArgJacobianPrep{B,O} <: JacobianPrep
178+
struct EnzymeForwardOneArgJacobianPrep{B,O} <: DI.JacobianPrep
179179
shadows::O
180180
output_length::Int
181181
end

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,19 @@ function DI.prepare_pushforward(
66
::AutoEnzyme{<:Union{ForwardMode,Nothing}},
77
x,
88
tx::NTuple,
9-
contexts::Vararg{Context,C},
9+
contexts::Vararg{DI.Context,C},
1010
) where {F,C}
11-
return NoPushforwardPrep()
11+
return DI.NoPushforwardPrep()
1212
end
1313

1414
function DI.value_and_pushforward(
1515
f!::F,
1616
y,
17-
::NoPushforwardPrep,
17+
::DI.NoPushforwardPrep,
1818
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
1919
x,
2020
tx::NTuple{1},
21-
contexts::Vararg{Context,C},
21+
contexts::Vararg{DI.Context,C},
2222
) where {F,C}
2323
f!_and_df! = get_f_and_df(f!, backend)
2424
dx_sametype = convert(typeof(x), only(tx))
@@ -39,11 +39,11 @@ end
3939
function DI.value_and_pushforward(
4040
f!::F,
4141
y,
42-
::NoPushforwardPrep,
42+
::DI.NoPushforwardPrep,
4343
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
4444
x,
4545
tx::NTuple{B},
46-
contexts::Vararg{Context,C},
46+
contexts::Vararg{DI.Context,C},
4747
) where {F,B,C}
4848
f!_and_df! = get_f_and_df(f!, backend, Val(B))
4949
tx_sametype = map(Fix1(convert, typeof(x)), tx)
@@ -64,11 +64,11 @@ end
6464
function DI.pushforward(
6565
f!::F,
6666
y,
67-
prep::NoPushforwardPrep,
67+
prep::DI.NoPushforwardPrep,
6868
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
6969
x,
7070
tx::NTuple,
71-
contexts::Vararg{Context,C},
71+
contexts::Vararg{DI.Context,C},
7272
) where {F,C}
7373
_, ty = DI.value_and_pushforward(f!, y, prep, backend, x, tx, contexts...)
7474
return ty
@@ -78,11 +78,11 @@ function DI.value_and_pushforward!(
7878
f!::F,
7979
y,
8080
ty::NTuple,
81-
prep::NoPushforwardPrep,
81+
prep::DI.NoPushforwardPrep,
8282
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
8383
x,
8484
tx::NTuple,
85-
contexts::Vararg{Context,C},
85+
contexts::Vararg{DI.Context,C},
8686
) where {F,C}
8787
y, new_ty = DI.value_and_pushforward(f!, y, prep, backend, x, tx, contexts...)
8888
foreach(copyto!, ty, new_ty)
@@ -93,11 +93,11 @@ function DI.pushforward!(
9393
f!::F,
9494
y,
9595
ty::NTuple,
96-
prep::NoPushforwardPrep,
96+
prep::DI.NoPushforwardPrep,
9797
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
9898
x,
9999
tx::NTuple,
100-
contexts::Vararg{Context,C},
100+
contexts::Vararg{DI.Context,C},
101101
) where {F,C}
102102
new_ty = DI.pushforward(f!, y, prep, backend, x, tx, contexts...)
103103
foreach(copyto!, ty, new_ty)

0 commit comments

Comments
 (0)