Skip to content

Commit 4cfd3d0

Browse files
authored
Introduce type hierarchy for extras (#127)
* Introduce type hierarchy for extras * Fix docs
1 parent 3bfcd1f commit 4cfd3d0

49 files changed

Lines changed: 1011 additions & 389 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

docs/src/core.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,41 +14,47 @@ DifferentiationInterface
1414
```@autodocs
1515
Modules = [DifferentiationInterface]
1616
Pages = ["src/derivative.jl"]
17+
Private = false
1718
```
1819

1920
## Gradient
2021

2122
```@autodocs
2223
Modules = [DifferentiationInterface]
2324
Pages = ["gradient.jl"]
25+
Private = false
2426
```
2527

2628
## Jacobian
2729

2830
```@autodocs
2931
Modules = [DifferentiationInterface]
3032
Pages = ["jacobian.jl"]
33+
Private = false
3134
```
3235

3336
## Second order
3437

3538
```@autodocs
3639
Modules = [DifferentiationInterface]
3740
Pages = ["second_order.jl", "second_derivative.jl", "hessian.jl", "hvp.jl"]
41+
Private = false
3842
```
3943

4044
## Primitives
4145

4246
```@autodocs
4347
Modules = [DifferentiationInterface]
4448
Pages = ["pushforward.jl", "pullback.jl"]
49+
Private = false
4550
```
4651

4752
## Backend queries
4853

4954
```@autodocs
5055
Modules = [DifferentiationInterface]
5156
Pages = ["backends.jl"]
57+
Private = false
5258
```
5359

5460
## Internals
@@ -58,6 +64,4 @@ This is not part of the public API.
5864
```@autodocs
5965
Modules = [DifferentiationInterface]
6066
Public = false
61-
Order = [:function, :type]
62-
Filter = t -> !(t isa Type && t <: ADTypes.AbstractADType)
6367
```

ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using ADTypes: ADTypes, AutoChainRules
44
using ChainRulesCore:
55
HasForwardsMode, HasReverseMode, NoTangent, RuleConfig, frule_via_ad, rrule_via_ad
66
import DifferentiationInterface as DI
7+
using DifferentiationInterface: NoPullbackExtras, NoPushforwardExtras
78

89
ruleconfig(backend::AutoChainRules) = backend.ruleconfig
910

@@ -14,15 +15,23 @@ DI.supports_mutation(::AutoChainRules) = DI.MutationNotSupported()
1415
DI.mode(::AutoForwardChainRules) = ADTypes.AbstractForwardMode
1516
DI.mode(::AutoReverseChainRules) = ADTypes.AbstractReverseMode
1617

17-
## Primitives
18+
## Pushforward
1819

19-
function DI.value_and_pushforward(f, backend::AutoForwardChainRules, x, dx, extras::Nothing)
20+
DI.prepare_pushforward(f, ::AutoForwardChainRules, x) = NoPushforwardExtras()
21+
22+
function DI.value_and_pushforward(
23+
f, backend::AutoForwardChainRules, x, dx, ::NoPushforwardExtras
24+
)
2025
rc = ruleconfig(backend)
2126
y, new_dy = frule_via_ad(rc, (NoTangent(), dx), f, x)
2227
return y, new_dy
2328
end
2429

25-
function DI.value_and_pullback(f, backend::AutoReverseChainRules, x, dy, extras::Nothing)
30+
## Pullback
31+
32+
DI.prepare_pullback(f, ::AutoForwardChainRules, x) = NoPullbackExtras()
33+
34+
function DI.value_and_pullback(f, backend::AutoReverseChainRules, x, dy, ::NoPullbackExtras)
2635
rc = ruleconfig(backend)
2736
y, pullback = rrule_via_ad(rc, f, x)
2837
_, new_dx = pullback(dy)

ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,18 @@ module DifferentiationInterfaceDiffractorExt
33
import AbstractDifferentiation as AD # public API for Diffractor
44
using ADTypes: ADTypes, AutoChainRules, AutoDiffractor
55
import DifferentiationInterface as DI
6+
using DifferentiationInterface: NoPushforwardExtras
67
using Diffractor: DiffractorForwardBackend, DiffractorRuleConfig
78

89
DI.supports_mutation(::AutoDiffractor) = DI.MutationNotSupported()
910
DI.mode(::AutoDiffractor) = ADTypes.AbstractForwardMode
1011
DI.mode(::AutoChainRules{<:DiffractorRuleConfig}) = ADTypes.AbstractForwardMode
1112

12-
function DI.value_and_pushforward(f, ::AutoDiffractor, x, dx, extras::Nothing)
13+
## Pushforward
14+
15+
DI.prepare_pushforward(f, ::AutoDiffractor, x) = NoPushforwardExtras()
16+
17+
function DI.value_and_pushforward(f, ::AutoDiffractor, x, dx, ::NoPushforwardExtras)
1318
vpff = AD.value_and_pushforward_function(DiffractorForwardBackend(), f, x)
1419
y, dy = vpff((dx,))
1520
return y, dy

ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@ module DifferentiationInterfaceEnzymeExt
22

33
using ADTypes: ADTypes, AutoEnzyme
44
import DifferentiationInterface as DI
5+
using DifferentiationInterface:
6+
NoDerivativeExtras,
7+
NoGradientExtras,
8+
NoJacobianExtras,
9+
NoPullbackExtras,
10+
NoPushforwardExtras
511
using DocStringExtensions
612
using Enzyme:
713
Active,
Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,76 +1,86 @@
11
## Pushforward
22

3-
function DI.value_and_pushforward(f, backend::AutoForwardEnzyme, x, dx, extras::Nothing)
3+
DI.prepare_pushforward(f, ::AutoForwardEnzyme, x) = NoPushforwardExtras()
4+
5+
function DI.value_and_pushforward(
6+
f, backend::AutoForwardEnzyme, x, dx, ::NoPushforwardExtras
7+
)
48
dx_sametype = convert(typeof(x), dx)
59
y, new_dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx_sametype))
610
return y, new_dy
711
end
812

9-
function DI.pushforward(f, backend::AutoForwardEnzyme, x, dx, extras::Nothing)
13+
function DI.pushforward(f, backend::AutoForwardEnzyme, x, dx, ::NoPushforwardExtras)
1014
dx_sametype = convert(typeof(x), dx)
1115
new_dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx_sametype)))
1216
return new_dy
1317
end
1418

1519
function DI.value_and_pushforward!!(
16-
f, _dy, backend::AutoForwardEnzyme, x, dx, extras::Nothing
20+
f, _dy, backend::AutoForwardEnzyme, x, dx, extras::NoPushforwardExtras
1721
)
1822
# dy cannot be passed anyway
1923
return DI.value_and_pushforward(f, backend, x, dx, extras)
2024
end
2125

22-
function DI.pushforward!!(f, _dy, backend::AutoForwardEnzyme, x, dx, extras::Nothing)
26+
function DI.pushforward!!(
27+
f, _dy, backend::AutoForwardEnzyme, x, dx, extras::NoPushforwardExtras
28+
)
2329
# dy cannot be passed anyway
2430
return DI.pushforward(f, backend, x, dx, extras)
2531
end
2632

2733
## Gradient
2834

29-
function DI.gradient(f, backend::AutoForwardEnzyme, x::AbstractArray, extras::Nothing)
35+
DI.prepare_gradient(f, ::AutoForwardEnzyme, x) = NoGradientExtras()
36+
37+
function DI.gradient(f, backend::AutoForwardEnzyme, x::AbstractArray, ::NoGradientExtras)
3038
return reshape(collect(gradient(backend.mode, f, x)), size(x))
3139
end
3240

3341
function DI.value_and_gradient(
34-
f, backend::AutoForwardEnzyme, x::AbstractArray, extras::Nothing
42+
f, backend::AutoForwardEnzyme, x::AbstractArray, extras::NoGradientExtras
3543
)
3644
return f(x), DI.gradient(f, backend, x, extras)
3745
end
3846

3947
function DI.gradient!!(
40-
f, _grad, backend::AutoForwardEnzyme, x::AbstractArray, extras::Nothing
48+
f, _grad, backend::AutoForwardEnzyme, x::AbstractArray, extras::NoGradientExtras
4149
)
4250
return DI.gradient(f, backend, x, extras)
4351
end
4452

4553
function DI.value_and_gradient!!(
46-
f, _grad, backend::AutoForwardEnzyme, x::AbstractArray, extras::Nothing
54+
f, _grad, backend::AutoForwardEnzyme, x::AbstractArray, extras::NoGradientExtras
4755
)
4856
return DI.value_and_gradient(f, backend, x, extras)
4957
end
5058

5159
## Jacobian
5260

53-
function DI.jacobian(f, backend::AutoForwardEnzyme, x::AbstractArray, extras::Nothing)
61+
DI.prepare_jacobian(f, ::AutoForwardEnzyme, x) = NoJacobianExtras()
62+
63+
function DI.jacobian(f, backend::AutoForwardEnzyme, x::AbstractArray, ::NoJacobianExtras)
5464
jac_wrongshape = jacobian(backend.mode, f, x)
5565
nx = length(x)
5666
ny = length(jac_wrongshape) ÷ length(x)
5767
return reshape(jac_wrongshape, ny, nx)
5868
end
5969

6070
function DI.value_and_jacobian(
61-
f, backend::AutoForwardEnzyme, x::AbstractArray, extras::Nothing
71+
f, backend::AutoForwardEnzyme, x::AbstractArray, extras::NoJacobianExtras
6272
)
6373
return f(x), DI.jacobian(f, backend, x, extras)
6474
end
6575

6676
function DI.jacobian!!(
67-
f, _jac, backend::AutoForwardEnzyme, x::AbstractArray, extras::Nothing
77+
f, _jac, backend::AutoForwardEnzyme, x::AbstractArray, extras::NoJacobianExtras
6878
)
6979
return DI.jacobian(f, backend, x, extras)
7080
end
7181

7282
function DI.value_and_jacobian!!(
73-
f, _jac, backend::AutoForwardEnzyme, x::AbstractArray, extras::Nothing
83+
f, _jac, backend::AutoForwardEnzyme, x::AbstractArray, extras::NoJacobianExtras
7484
)
7585
return DI.value_and_jacobian(f, backend, x, extras)
7686
end

ext/DifferentiationInterfaceEnzymeExt/forward_mutating.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
## Pushforward
22

3+
DI.prepare_pushforward(f!, ::AutoForwardEnzyme, y, x) = NoPushforwardExtras()
4+
35
function DI.value_and_pushforward!!(
4-
f!, y, dy, backend::AutoForwardEnzyme, x, dx, extras::Nothing
6+
f!, y, dy, backend::AutoForwardEnzyme, x, dx, ::NoPushforwardExtras
57
)
68
dx_sametype = convert(typeof(x), dx)
79
dy_sametype = zero_sametype!!(dy, y)

ext/DifferentiationInterfaceEnzymeExt/reverse_allocating.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
## Pullback
22

3+
DI.prepare_pullback(f, ::AutoReverseEnzyme, x) = NoPullbackExtras()
4+
35
### Out-of-place
46

57
function DI.value_and_pullback(
6-
f, ::AutoReverseEnzyme, x::Number, dy::Number, extras::Nothing
8+
f, ::AutoReverseEnzyme, x::Number, dy::Number, ::NoPullbackExtras
79
)
810
der, y = autodiff(ReverseWithPrimal, f, Active, Active(x))
911
new_dx = dy * only(der)
1012
return y, new_dx
1113
end
1214

1315
function DI.value_and_pullback(
14-
f, ::AutoReverseEnzyme, x::Number, dy::AbstractArray, extras::Nothing
16+
f, ::AutoReverseEnzyme, x::Number, dy::AbstractArray, ::NoPullbackExtras
1517
)
1618
forw, rev = autodiff_thunk(
1719
ReverseSplitWithPrimal, Const{typeof(f)}, Duplicated, Active{typeof(x)}
@@ -25,7 +27,7 @@ end
2527
### In-place
2628

2729
function DI.value_and_pullback!!(
28-
f, dx, ::AutoReverseEnzyme, x::AbstractArray, dy::Number, extras::Nothing
30+
f, dx, ::AutoReverseEnzyme, x::AbstractArray, dy::Number, ::NoPullbackExtras
2931
)
3032
dx_sametype = zero_sametype!!(dx, x)
3133
_, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx_sametype))
@@ -34,7 +36,7 @@ function DI.value_and_pullback!!(
3436
end
3537

3638
function DI.value_and_pullback!!(
37-
f, dx, ::AutoReverseEnzyme, x::AbstractArray, dy::AbstractArray, extras::Nothing
39+
f, dx, ::AutoReverseEnzyme, x::AbstractArray, dy::AbstractArray, ::NoPullbackExtras
3840
)
3941
dx_sametype = zero_sametype!!(dx, x)
4042
forw, rev = autodiff_thunk(
@@ -46,18 +48,22 @@ function DI.value_and_pullback!!(
4648
return y, dx_sametype
4749
end
4850

49-
function DI.value_and_pullback(f, backend::AutoReverseEnzyme, x::AbstractArray, dy, extras)
51+
function DI.value_and_pullback(
52+
f, backend::AutoReverseEnzyme, x::AbstractArray, dy, extras::NoPullbackExtras
53+
)
5054
dx = similar(x)
5155
return DI.value_and_pullback!!(f, dx, backend, x, dy, extras)
5256
end
5357

5458
## Gradient
5559

56-
function DI.gradient(f, ::AutoReverseEnzyme, x::AbstractArray, extras::Nothing)
60+
DI.prepare_gradient(f, ::AutoReverseEnzyme) = NoGradientExtras()
61+
62+
function DI.gradient(f, ::AutoReverseEnzyme, x::AbstractArray, ::NoGradientExtras)
5763
return gradient(Reverse, f, x)
5864
end
5965

60-
function DI.gradient!!(f, grad, ::AutoReverseEnzyme, x::AbstractArray, extras::Nothing)
66+
function DI.gradient!!(f, grad, ::AutoReverseEnzyme, x::AbstractArray, ::NoGradientExtras)
6167
grad_sametype = convert(typeof(x), grad)
6268
gradient!(Reverse, grad_sametype, f, x)
6369
return grad_sametype

ext/DifferentiationInterfaceEnzymeExt/reverse_mutating.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
## Pullback
22

3+
DI.prepare_pullback(f!, ::AutoReverseEnzyme, y, x) = NoPullbackExtras()
4+
35
function DI.value_and_pullback!!(
4-
f!, y, _dx, ::AutoReverseEnzyme, x::Number, dy, extras::Nothing
6+
f!, y, _dx, ::AutoReverseEnzyme, x::Number, dy, ::NoPullbackExtras
57
)
68
dy_sametype = convert(typeof(y), copy(dy))
79
_, new_dx = only(autodiff(Reverse, f!, Const, Duplicated(y, dy_sametype), Active(x)))
810
return y, new_dx
911
end
1012

1113
function DI.value_and_pullback!!(
12-
f!, y, dx, ::AutoReverseEnzyme, x::AbstractArray, dy, extras::Nothing
14+
f!, y, dx, ::AutoReverseEnzyme, x::AbstractArray, dy, ::NoPullbackExtras
1315
)
1416
dx_sametype = zero_sametype!!(dx, x)
1517
dy_sametype = convert(typeof(y), copy(dy))

ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
11
module DifferentiationInterfaceFastDifferentiationExt
22

33
using ADTypes: ADTypes
4-
using DifferentiationInterface: AutoFastDifferentiation, AutoSparseFastDifferentiation
54
import DifferentiationInterface as DI
5+
using DifferentiationInterface: AutoFastDifferentiation, AutoSparseFastDifferentiation
6+
using DifferentiationInterface:
7+
DerivativeExtras,
8+
GradientExtras,
9+
HessianExtras,
10+
HVPExtras,
11+
JacobianExtras,
12+
PullbackExtras,
13+
PushforwardExtras
614
using FastDifferentiation:
715
derivative,
816
hessian,

0 commit comments

Comments
 (0)